Commit c7a2007

Davies Liu committed
switch to python implementation
1 parent 95a48ac commit c7a2007

File tree

3 files changed: +16 -26 lines


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 0 additions & 7 deletions

@@ -758,13 +758,6 @@ private[spark] object PythonRDD extends Logging {
       converted.saveAsHadoopDataset(new JobConf(conf))
     }
   }
-
-  /**
-   * A helper to convert java.util.List[Double] into Array[Double]
-   */
-  def listToArrayDouble(list: JList[Double]): Array[Double] = {
-    list.asScala.toArray
-  }
 }
 
 private

python/pyspark/rdd.py

Lines changed: 13 additions & 17 deletions

@@ -28,7 +28,7 @@
 import warnings
 import heapq
 import bisect
-from random import Random
+import random
 from math import sqrt, log, isinf, isnan
 
 from pyspark.accumulators import PStatsParam
@@ -324,25 +324,21 @@ def randomSplit(self, weights, seed=None):
         :param seed: random seed
         :return: split RDDs in a list
 
-        >>> rdd = sc.parallelize(range(10), 1)
-        >>> rdd1, rdd2, rdd3 = rdd.randomSplit([0.4, 0.6, 1.0], 11)
+        >>> rdd = sc.parallelize(range(5), 1)
+        >>> rdd1, rdd2 = rdd.randomSplit([2.0, 3.0], 101)
         >>> rdd1.collect()
-        [3, 6]
+        [2, 3]
         >>> rdd2.collect()
-        [0, 5, 7]
-        >>> rdd3.collect()
-        [1, 2, 4, 8, 9]
+        [0, 1, 4]
         """
-        ser = BatchedSerializer(PickleSerializer(), 1)
-        rdd = self._reserialize(ser)
-        jweights = ListConverter().convert([float(w) for w in weights],
-                                           self.ctx._gateway._gateway_client)
-        jweights = self.ctx._jvm.PythonRDD.listToArrayDouble(jweights)
+        s = sum(weights)
+        cweights = [0.0]
+        for w in weights:
+            cweights.append(cweights[-1] + w / s)
         if seed is None:
-            jrdds = rdd._jrdd.randomSplit(jweights)
-        else:
-            jrdds = rdd._jrdd.randomSplit(jweights, seed)
-        return [RDD(jrdd, self.ctx, ser) for jrdd in jrdds]
+            seed = random.randint(0, 2 ** 32 - 1)
+        return [self.mapPartitionsWithIndex(RDDSampler(False, ub, seed, lb).func, True)
+                for lb, ub in zip(cweights, cweights[1:])]
 
     # this is ported from scala/spark/RDD.scala
     def takeSample(self, withReplacement, num, seed=None):
@@ -369,7 +365,7 @@ def takeSample(self, withReplacement, num, seed=None):
         if initialCount == 0:
             return []
 
-        rand = Random(seed)
+        rand = random.Random(seed)
 
         if (not withReplacement) and num >= initialCount:
             # shuffle current RDD and return
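The rewritten randomSplit is worth spelling out: the weights are normalized into cumulative bounds over [0, 1], and one sampler is built per adjacent pair of bounds, all sharing a single seed. Because every pass draws the same uniform value for a given element, each element lands in exactly one half-open interval. A minimal sketch of the technique on plain Python lists (cumulative_bounds and split_by_weights are illustrative names, not part of the commit):

import random

def cumulative_bounds(weights):
    # Normalize weights into cumulative bounds over [0, 1].
    s = sum(weights)
    bounds = [0.0]
    for w in weights:
        bounds.append(bounds[-1] + w / s)
    return list(zip(bounds, bounds[1:]))

def split_by_weights(items, weights, seed):
    # One pass per split, all seeded identically: the i-th draw is the
    # same in every pass, so each item falls in exactly one [lb, ub).
    splits = []
    for lb, ub in cumulative_bounds(weights):
        rand = random.Random(seed)
        splits.append([x for x in items if lb <= rand.random() < ub])
    return splits

print(split_by_weights(range(10), [2.0, 3.0], seed=101))

Sharing the seed across passes is the load-bearing detail; with independent seeds the splits would overlap and drop elements.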

python/pyspark/rddsampler.py

Lines changed: 3 additions & 2 deletions

@@ -96,9 +96,10 @@ def shuffle(self, vals):
 
 class RDDSampler(RDDSamplerBase):
 
-    def __init__(self, withReplacement, fraction, seed=None):
+    def __init__(self, withReplacement, fraction, seed=None, lowbound=0.0):
         RDDSamplerBase.__init__(self, withReplacement, seed)
         self._fraction = fraction
+        self._lowbound = lowbound
 
     def func(self, split, iterator):
         if self._withReplacement:
@@ -111,7 +112,7 @@ def func(self, split, iterator):
                     yield obj
         else:
             for obj in iterator:
-                if self.getUniformSample(split) <= self._fraction:
+                if self._lowbound <= self.getUniformSample(split) < self._fraction:
                     yield obj