Skip to content

Commit 7001b51

Browse files
committed
refactor of queueStream()
1 parent 26ea396 commit 7001b51

File tree

2 files changed

+19
-47
lines changed

2 files changed

+19
-47
lines changed

python/pyspark/streaming/context.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def _check_serialzers(self, rdds):
184184
# reset them to sc.serializer
185185
rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True)
186186

187-
def queueStream(self, queue, oneAtATime=False, default=None):
187+
def queueStream(self, queue, oneAtATime=True, default=None):
188188
"""
189189
Create an input stream from a queue of RDDs or a list. In each batch,
190190
it will process either one or all of the RDDs returned by the queue.
@@ -200,9 +200,12 @@ def queueStream(self, queue, oneAtATime=False, default=None):
200200
self._check_serialzers(rdds)
201201
jrdds = ListConverter().convert([r._jrdd for r in rdds],
202202
SparkContext._gateway._gateway_client)
203-
jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds, oneAtATime,
204-
default and default._jrdd)
205-
return DStream(jdstream.asJavaDStream(), self, rdds[0]._jrdd_deserializer)
203+
queue = self._jvm.PythonDStream.toRDDQueue(jrdds)
204+
if default:
205+
jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
206+
else:
207+
jdstream = self._jssc.queueStream(queue, oneAtATime)
208+
return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
206209

207210
def transform(self, dstreams, transformFunc):
208211
"""

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala

Lines changed: 12 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.streaming.api.python
1919

2020
import java.util.{ArrayList => JArrayList}
21+
import scala.collection.JavaConversions._
2122

2223
import org.apache.spark.rdd.RDD
2324
import org.apache.spark.api.java._
@@ -65,6 +66,16 @@ abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (p
6566
val asJavaDStream = JavaDStream.fromDStream(this)
6667
}
6768

69+
object PythonDStream {
70+
71+
// convert a list of RDDs into a queue of RDDs, for ssc.queueStream()
72+
def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = {
73+
val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]]
74+
rdds.forall(queue.add(_))
75+
queue
76+
}
77+
}
78+
6879
/**
6980
* Transformed DStream in Python.
7081
*
@@ -243,46 +254,4 @@ class PythonForeachDStream(
243254
) {
244255

245256
this.register()
246-
}
247-
248-
249-
/**
250-
* similar to QueueInputStream
251-
*/
252-
class PythonDataInputStream(
253-
ssc_ : JavaStreamingContext,
254-
inputRDDs: JArrayList[JavaRDD[Array[Byte]]],
255-
oneAtAtime: Boolean,
256-
defaultRDD: JavaRDD[Array[Byte]]
257-
) extends InputDStream[Array[Byte]](JavaStreamingContext.toStreamingContext(ssc_)) {
258-
259-
val emptyRDD = if (defaultRDD != null) {
260-
Some(defaultRDD.rdd)
261-
} else {
262-
Some(ssc.sparkContext.emptyRDD[Array[Byte]])
263-
}
264-
265-
def start() {}
266-
267-
def stop() {}
268-
269-
def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
270-
val index = ((validTime - zeroTime) / slideDuration - 1).toInt
271-
if (oneAtAtime) {
272-
if (index == 0) {
273-
val rdds = inputRDDs.toArray.map(_.asInstanceOf[JavaRDD[Array[Byte]]].rdd).toSeq
274-
Some(ssc.sparkContext.union(rdds))
275-
} else {
276-
emptyRDD
277-
}
278-
} else {
279-
if (index < inputRDDs.size()) {
280-
Some(inputRDDs.get(index).rdd)
281-
} else {
282-
emptyRDD
283-
}
284-
}
285-
}
286-
287-
val asJavaDStream = JavaDStream.fromDStream(this)
288-
}
257+
}

0 commit comments

Comments
 (0)