Skip to content

Commit 41fce54

Browse files
author
Davies Liu
committed
randomSplit()
1 parent c6f4e70 commit 41fce54

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,19 @@ private[spark] object PythonRDD extends Logging {
757757
converted.saveAsHadoopDataset(new JobConf(conf))
758758
}
759759
}
760+
761+
/**
 * Convert a java.util.List[Double] into an Array[Double].
 *
 * Uses an explicit index loop over `list.get(i)` rather than calling
 * `zipWithIndex` on the Java list: that call only compiles through the
 * deprecated implicit `JavaConversions` wrapping and allocates a
 * (value, index) tuple per element. The loop needs no implicit
 * conversion and fills the primitive array directly.
 *
 * @param list the Java list to convert; must not be null
 * @return an Array[Double] containing the list's elements in order
 */
def listToArrayDouble(list: JList[Double]): Array[Double] = {
  val r = new Array[Double](list.size)
  var i = 0
  while (i < r.length) {
    r(i) = list.get(i)
    i += 1
  }
  r
}
760773
}
761774

762775
private

python/pyspark/rdd.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,34 @@ def sample(self, withReplacement, fraction, seed=None):
316316
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
317317
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
318318

319+
def randomSplit(self, weights, seed=None):
    """
    Randomly splits this RDD with the provided weights.

    :param weights: weights for splits, will be normalized if they don't sum to 1
    :param seed: random seed
    :return: split RDDs in a list

    >>> rdd = sc.parallelize(range(10), 1)
    >>> rdd1, rdd2, rdd3 = rdd.randomSplit([0.4, 0.6, 1.0], 11)
    >>> rdd1.collect()
    [3, 6]
    >>> rdd2.collect()
    [0, 5, 7]
    >>> rdd3.collect()
    [1, 2, 4, 8, 9]
    """
    # Re-serialize one element per batch so the JVM-side split sees
    # individual records rather than pickled batches.
    ser = BatchedSerializer(PickleSerializer(), 1)
    reserialized = self._reserialize(ser)
    # Ship the (float-coerced) weights across the Py4J gateway as a
    # Java list, then convert to double[] on the JVM side.
    weight_doubles = [float(w) for w in weights]
    gateway_client = self.ctx._gateway._gateway_client
    jlist = ListConverter().convert(weight_doubles, gateway_client)
    jweights = self.ctx._jvm.PythonRDD.listToArrayDouble(jlist)
    # Omit the seed argument entirely when none was given, so the JVM
    # picks its own default seed.
    if seed is None:
        split_jrdds = reserialized._jrdd.randomSplit(jweights)
    else:
        split_jrdds = reserialized._jrdd.randomSplit(jweights, seed)
    # Wrap each JVM-side split back into a Python RDD.
    return [RDD(split, self.ctx, ser) for split in split_jrdds]
346+
319347
# this is ported from scala/spark/RDD.scala
320348
def takeSample(self, withReplacement, num, seed=None):
321349
"""

0 commit comments

Comments
 (0)