Skip to content

Commit 12280b5

Browse files
committed
Merge pull request #6 from markhamstra/streamingIterable
SPY-287 updated streaming iterable
2 parents 397f05f + ede65ec commit 12280b5

File tree

4 files changed

+103
-76
lines changed

4 files changed

+103
-76
lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import org.apache.spark.partial.CountEvaluator
3838
import org.apache.spark.partial.GroupedCountEvaluator
3939
import org.apache.spark.partial.PartialResult
4040
import org.apache.spark.storage.StorageLevel
41-
import org.apache.spark.util.{RDDiterable, Utils, BoundedPriorityQueue}
41+
import org.apache.spark.util.{Utils, BoundedPriorityQueue}
4242

4343
import org.apache.spark.SparkContext._
4444
import org.apache.spark._
@@ -576,8 +576,6 @@ abstract class RDD[T: ClassManifest](
576576
sc.runJob(this, (iter: Iterator[T]) => f(iter))
577577
}
578578

579-
580-
581579
/**
582580
* Return an array that contains all of the elements in this RDD.
583581
*/
@@ -599,14 +597,16 @@ abstract class RDD[T: ClassManifest](
599597
}
600598

601599
/**
602-
* Return iterable that lazily fetches partitions
603-
* @param prefetchPartitions How many partitions to prefetch. Larger value increases parallelism but also increases
604-
* driver memory requirement
600+
* Return an iterator that lazily fetches partitions
601+
* @param prefetchPartitions How many partitions to prefetch. Larger value increases parallelism
602+
* but also increases driver memory requirement.
603+
* @param partitionBatchSize How many partitions to fetch per job
605604
* @param timeOut how long to wait for each partition fetch
606605
* @return Iterator over every element in this RDD
607606
*/
608-
def toIterable(prefetchPartitions: Int = 1, timeOut: Duration = Duration(30, TimeUnit.SECONDS)) = {
609-
new RDDiterable[T](this, prefetchPartitions, timeOut)
607+
def toIterator(prefetchPartitions: Int = 1, partitionBatchSize: Int = 10,
608+
timeOut: Duration = Duration(30, TimeUnit.SECONDS)):Iterator[T] = {
609+
new RDDiterator[T](this, prefetchPartitions,partitionBatchSize, timeOut)
610610
}
611611

612612
/**
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package org.apache.spark.rdd
2+
3+
import scala.concurrent.{Await, Future}
4+
import scala.collection.mutable.ArrayBuffer
5+
import scala.concurrent.duration.Duration
6+
import scala.annotation.tailrec
7+
import scala.collection.mutable
8+
import org.apache.spark.rdd.RDDiterator._
9+
import org.apache.spark.FutureAction
10+
11+
/**
12+
* Iterator that iterates over all elements of an RDD without fetching all partitions
13+
* to the driver process
14+
*
15+
* @param rdd RDD to iterate
16+
* @param prefetchPartitions The number of partitions to prefetch.
17+
* If <1 will not prefetch.
18+
* batches prefetched up front = max(0, prefetchPartitions / max(1, partitionBatchSize))
19+
* @param partitionBatchSize How many partitions to fetch per job
20+
* @param timeOut How long to wait for each partition before failing.
21+
*/
22+
class RDDiterator[T: ClassManifest](rdd: RDD[T], prefetchPartitions: Int, partitionBatchSize: Int,
23+
timeOut: Duration)
24+
extends Iterator[T] {
25+
26+
val batchSize = math.max(1,partitionBatchSize)
27+
var partitionsBatches: Iterator[Seq[Int]] = Range(0, rdd.partitions.size).grouped(batchSize)
28+
var pendingFetchesQueue = mutable.Queue.empty[Future[Seq[Seq[T]]]]
29+
//add prefetchPartitions prefetch
30+
0.until(math.max(0, prefetchPartitions / batchSize)).foreach(x=>enqueueDataFetch())
31+
32+
var currentIterator: Iterator[T] = Iterator.empty
33+
@tailrec
34+
final def hasNext = {
35+
if (currentIterator.hasNext) {
36+
//Still values in the current partition
37+
true
38+
} else {
39+
//Move on to the next partition
40+
//Queue new prefetch of a partition
41+
enqueueDataFetch()
42+
if (pendingFetchesQueue.isEmpty) {
43+
//No more partitions
44+
currentIterator = Iterator.empty
45+
false
46+
} else {
47+
val future = pendingFetchesQueue.dequeue()
48+
currentIterator = Await.result(future, timeOut).flatMap(x => x).iterator
49+
//Next partition might be empty so check again.
50+
this.hasNext
51+
}
52+
}
53+
}
54+
def next() = {
55+
hasNext
56+
currentIterator.next()
57+
}
58+
59+
def enqueueDataFetch() ={
60+
if (partitionsBatches.hasNext) {
61+
pendingFetchesQueue.enqueue(fetchData(partitionsBatches.next(), rdd))
62+
}
63+
}
64+
}
65+
66+
object RDDiterator {
67+
private def fetchData[T: ClassManifest](partitionIndexes: Seq[Int],
68+
rdd: RDD[T]): FutureAction[Seq[Seq[T]]] = {
69+
val results = new ArrayBuffer[Seq[T]]()
70+
rdd.context.submitJob[T, Array[T], Seq[Seq[T]]](rdd,
71+
x => x.toArray,
72+
partitionIndexes,
73+
(inx: Int, res: Array[T]) => results.append(res),
74+
results.toSeq)
75+
}
76+
}

core/src/main/scala/org/apache/spark/util/RDDiterable.scala

Lines changed: 0 additions & 59 deletions
This file was deleted.

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -342,23 +342,33 @@ class RDDSuite extends FunSuite with SharedSparkContext {
342342

343343
test("toIterable") {
344344
var nums = sc.makeRDD(Range(1, 1000), 100)
345-
assert(nums.toIterable(prefetchPartitions = 10).size === 999)
346-
assert(nums.toIterable().toArray === (1 to 999).toArray)
345+
assert(nums.toIterator(prefetchPartitions = 10).size === 999)
346+
assert(nums.toIterator().toArray === (1 to 999).toArray)
347347

348348
nums = sc.makeRDD(Range(1000, 1, -1), 100)
349-
assert(nums.toIterable(prefetchPartitions = 10).size === 999)
350-
assert(nums.toIterable(prefetchPartitions = 10).toArray === Range(1000, 1, -1).toArray)
349+
assert(nums.toIterator(prefetchPartitions = 10).size === 999)
350+
assert(nums.toIterator(prefetchPartitions = 10).toArray === Range(1000, 1, -1).toArray)
351351

352352
nums = sc.makeRDD(Range(1, 100), 1000)
353-
assert(nums.toIterable(prefetchPartitions = 10).size === 99)
354-
assert(nums.toIterable(prefetchPartitions = 10).toArray === Range(1, 100).toArray)
353+
assert(nums.toIterator(prefetchPartitions = 10).size === 99)
354+
assert(nums.toIterator(prefetchPartitions = 10).toArray === Range(1, 100).toArray)
355355

356356
nums = sc.makeRDD(Range(1, 1000), 100)
357-
assert(nums.toIterable(prefetchPartitions = -1).size === 999)
358-
assert(nums.toIterable().toArray === (1 to 999).toArray)
359-
}
357+
assert(nums.toIterator(prefetchPartitions = -1).size === 999)
358+
assert(nums.toIterator().toArray === (1 to 999).toArray)
359+
360+
nums = sc.makeRDD(Range(1, 1000), 100)
361+
assert(nums.toIterator(prefetchPartitions = 3,partitionBatchSize = 10).size === 999)
362+
assert(nums.toIterator().toArray === (1 to 999).toArray)
360363

364+
nums = sc.makeRDD(Range(1, 1000), 100)
365+
assert(nums.toIterator(prefetchPartitions = -1,partitionBatchSize = 0).size === 999)
366+
assert(nums.toIterator().toArray === (1 to 999).toArray)
361367

368+
nums = sc.makeRDD(Range(1, 1000), 100)
369+
assert(nums.toIterator(prefetchPartitions = -1).size === 999)
370+
assert(nums.toIterator().toArray === (1 to 999).toArray)
371+
}
362372

363373
test("take") {
364374
var nums = sc.makeRDD(Range(1, 1000), 1)

0 commit comments

Comments
 (0)