
Commit ab96fc3

Merge pull request #14 from markhamstra/streaming-iterable
SKIPME Streaming iterable
2 parents 1980448 + 4f51bdf commit ab96fc3

3 files changed: +121 -0 lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 15 additions & 0 deletions

@@ -45,6 +45,8 @@ import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableHyperLogL
 import org.apache.spark.SparkContext._
 import org.apache.spark._
+import scala.concurrent.duration.Duration
+import java.util.concurrent.TimeUnit

 /**
  * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -621,6 +623,19 @@ abstract class RDD[T: ClassTag](
     filter(f.isDefinedAt).map(f)
   }

+  /**
+   * Return an iterator that lazily fetches partitions of this RDD as they are consumed.
+   * @param prefetchPartitions How many partitions to prefetch. A larger value increases
+   *                           parallelism but also increases the driver's memory requirement.
+   * @param partitionBatchSize How many partitions to fetch per job
+   * @param timeOut How long to wait for each partition fetch
+   * @return Iterator over every element in this RDD
+   */
+  def toIterator(prefetchPartitions: Int = 1, partitionBatchSize: Int = 10,
+      timeOut: Duration = Duration(30, TimeUnit.SECONDS)): Iterator[T] = {
+    new RDDiterator[T](this, prefetchPartitions, partitionBatchSize, timeOut)
+  }
+
   /**
    * Return an RDD with the elements from `this` that are not in `other`.
    *
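A usage sketch (not part of the diff): with toIterator in place, a driver program can consume a large RDD incrementally instead of collect()-ing every partition at once. The input path and parameter values below are hypothetical.

    import scala.concurrent.duration.Duration
    import java.util.concurrent.TimeUnit

    val lines = sc.textFile("hdfs:///tmp/big-input")     // hypothetical input
    val it = lines.toIterator(
      prefetchPartitions = 4,                            // fetch ahead of consumption
      partitionBatchSize = 2,                            // partitions fetched per job
      timeOut = Duration(60, TimeUnit.SECONDS))          // per-fetch timeout
    it.foreach(println)                                  // partitions arrive lazily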
core/src/main/scala/org/apache/spark/rdd/RDDiterator.scala

Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+package org.apache.spark.rdd
+
+import scala.concurrent.{Await, Future}
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration.Duration
+import scala.annotation.tailrec
+import scala.collection.mutable
+import org.apache.spark.rdd.RDDiterator._
+import org.apache.spark.FutureAction
+
+/**
+ * Iterator over all elements of an RDD that does not fetch all partitions to the
+ * driver process at once.
+ *
+ * @param rdd RDD to iterate over
+ * @param prefetchPartitions The number of partitions to prefetch. If < 1, nothing is
+ *                           prefetched. The count is rounded down to a whole number of
+ *                           batches, so prefetchPartitions / partitionBatchSize fetch
+ *                           jobs are queued before iteration begins.
+ * @param partitionBatchSize How many partitions to fetch per job
+ * @param timeOut How long to wait for each partition fetch before failing.
+ */
+class RDDiterator[T: ClassManifest](rdd: RDD[T], prefetchPartitions: Int, partitionBatchSize: Int,
+    timeOut: Duration)
+  extends Iterator[T] {
+
+  val batchSize = math.max(1, partitionBatchSize)
+  val partitionsBatches: Iterator[Seq[Int]] = Range(0, rdd.partitions.size).grouped(batchSize)
+  val pendingFetchesQueue = mutable.Queue.empty[Future[Seq[Seq[T]]]]
+  // Queue the initial prefetch jobs.
+  0.until(math.max(0, prefetchPartitions / batchSize)).foreach(x => enqueueDataFetch())
+
+  var currentIterator: Iterator[T] = Iterator.empty
+
+  @tailrec
+  final def hasNext = {
+    if (currentIterator.hasNext) {
+      // Still values in the current partition.
+      true
+    } else {
+      // Move on to the next partition and queue a new prefetch to replace it.
+      enqueueDataFetch()
+      if (pendingFetchesQueue.isEmpty) {
+        // No more partitions.
+        currentIterator = Iterator.empty
+        false
+      } else {
+        val future = pendingFetchesQueue.dequeue()
+        currentIterator = Await.result(future, timeOut).flatMap(x => x).iterator
+        // The next partition might be empty, so check again.
+        this.hasNext
+      }
+    }
+  }
+
+  def next() = {
+    // hasNext advances currentIterator to the next non-empty partition if needed.
+    hasNext
+    currentIterator.next()
+  }
+
+  def enqueueDataFetch() = {
+    if (partitionsBatches.hasNext) {
+      pendingFetchesQueue.enqueue(fetchData(partitionsBatches.next(), rdd))
+    }
+  }
+}
+
+object RDDiterator {
+  private def fetchData[T: ClassManifest](partitionIndexes: Seq[Int],
+      rdd: RDD[T]): FutureAction[Seq[Seq[T]]] = {
+    val results = new ArrayBuffer[Seq[T]]()
+    rdd.context.submitJob[T, Array[T], Seq[Seq[T]]](rdd,
+      x => x.toArray,
+      partitionIndexes,
+      (index: Int, res: Array[T]) => results.append(res),
+      results.toSeq)
+  }
+}
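For readers unfamiliar with the API used in fetchData: SparkContext.submitJob runs a function over a chosen subset of an RDD's partitions asynchronously, delivers each partition's result to a driver-side callback, and returns a FutureAction for the aggregate result. A minimal standalone sketch of the same pattern (the sample data and partition choice are arbitrary):

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.FutureAction

    val counts = new ArrayBuffer[Long]()
    val future: FutureAction[Seq[Long]] = sc.submitJob[Int, Long, Seq[Long]](
      sc.parallelize(1 to 1000, 10),
      iter => iter.size.toLong,                            // runs on the executors
      Seq(0, 1, 2),                                        // only these partitions
      (index: Int, count: Long) => counts.append(count),   // driver-side callback
      counts.toSeq)                                        // produced when the job finishes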

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 30 additions & 0 deletions

@@ -381,6 +381,36 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
   }

+  test("toIterable") {
+    var nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = 10).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1000, 1, -1), 100)
+    assert(nums.toIterator(prefetchPartitions = 10).size === 999)
+    assert(nums.toIterator(prefetchPartitions = 10).toArray === Range(1000, 1, -1).toArray)
+
+    nums = sc.makeRDD(Range(1, 100), 1000)
+    assert(nums.toIterator(prefetchPartitions = 10).size === 99)
+    assert(nums.toIterator(prefetchPartitions = 10).toArray === Range(1, 100).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = -1).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = 3, partitionBatchSize = 10).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = -1, partitionBatchSize = 0).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = -1).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+  }
+
   test("take") {
     var nums = sc.makeRDD(Range(1, 1000), 1)
     assert(nums.take(0).size === 0)
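One detail worth noting about the parameter combinations exercised above (my reading of RDDiterator's initializer, not a statement from the patch): the number of fetch jobs queued before iteration starts is prefetchPartitions / max(1, partitionBatchSize), rounded down, so a prefetch count smaller than the batch size queues nothing up front.

    // Mirrors the arithmetic in RDDiterator's constructor.
    def prefetchJobs(prefetchPartitions: Int, partitionBatchSize: Int): Int = {
      val batchSize = math.max(1, partitionBatchSize)
      math.max(0, prefetchPartitions / batchSize)
    }
    prefetchJobs(10, 10)  // 1: one batch of 10 partitions is prefetched
    prefetchJobs(3, 10)   // 0: 3 partitions < one batch, so no up-front prefetch
    prefetchJobs(-1, 0)   // 0: negative prefetch disables prefetching (batch clamps to 1)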

0 commit comments