Commit 3f0fb4b

refactor fix tests

1 parent c28f520

File tree: 6 files changed, +288 -146 lines


python/pyspark/serializers.py

Lines changed: 3 additions & 0 deletions

@@ -114,6 +114,9 @@ def __ne__(self, other):
     def __repr__(self):
         return "<%s object>" % self.__class__.__name__

+    def __hash__(self):
+        return hash(str(self))
+

 class FramedSerializer(Serializer):
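
The new `__hash__` matters because serializer instances compare by value via `__eq__`/`__ne__`; without a matching hash, equal serializers would not collapse when placed in a set, which is what the deduplication in `StreamingContext._check_serialzers` below relies on. A minimal sketch of the effect, assuming the stock PickleSerializer:

    from pyspark.serializers import PickleSerializer

    # Two value-equal serializer instances now hash identically (hash(str(self))),
    # so putting them in a set collapses them to a single entry.
    serializers = {PickleSerializer(), PickleSerializer()}
    assert len(serializers) == 1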

python/pyspark/streaming/context.py

Lines changed: 111 additions & 18 deletions

@@ -15,16 +15,51 @@
 # limitations under the License.
 #

-from pyspark.serializers import UTF8Deserializer
+from pyspark import RDD
+from pyspark.serializers import UTF8Deserializer, BatchedSerializer
 from pyspark.context import SparkContext
+from pyspark.storagelevel import StorageLevel
 from pyspark.streaming.dstream import DStream
-from pyspark.streaming.duration import Duration, Seconds
+from pyspark.streaming.duration import Seconds

 from py4j.java_collections import ListConverter

 __all__ = ["StreamingContext"]


+def _daemonize_callback_server():
+    """
+    Hack Py4J to daemonize callback server
+    """
+    # TODO: create a patch for Py4J
+    import socket
+    import py4j.java_gateway
+    logger = py4j.java_gateway.logger
+    from py4j.java_gateway import Py4JNetworkError
+    from threading import Thread
+
+    def start(self):
+        """Starts the CallbackServer. This method should be called by the
+        client instead of run()."""
+        self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
+                                      1)
+        try:
+            self.server_socket.bind((self.address, self.port))
+            # self.port = self.server_socket.getsockname()[1]
+        except Exception:
+            msg = 'An error occurred while trying to start the callback server'
+            logger.exception(msg)
+            raise Py4JNetworkError(msg)
+
+        # Maybe thread needs to be cleanup up?
+        self.thread = Thread(target=self.run)
+        self.thread.daemon = True
+        self.thread.start()
+
+    py4j.java_gateway.CallbackServer.start = start
+
+
 class StreamingContext(object):
     """
     Main entry point for Spark Streaming functionality. A StreamingContext represents the
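
The helper above works by monkey-patching `CallbackServer.start` so the server thread is created as a daemon thread, letting the Python process exit without waiting on the callback server. A stripped-down sketch of that pattern, using a hypothetical `Server` class (not part of Py4J or PySpark):

    import threading

    class Server(object):
        def run(self):
            pass                            # serve requests here

        def start(self):
            self.thread = threading.Thread(target=self.run)
            self.thread.start()             # non-daemon thread keeps the process alive

    def daemon_start(self):
        self.thread = threading.Thread(target=self.run)
        self.thread.daemon = True           # daemon thread: interpreter may exit freely
        self.thread.start()

    Server.start = daemon_start             # same patching trick as CallbackServer.start above
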
@@ -53,7 +88,9 @@ def _start_callback_server(self):
         gw = self._sc._gateway
         # getattr will fallback to JVM
         if "_callback_server" not in gw.__dict__:
+            _daemonize_callback_server()
             gw._start_callback_server(gw._python_proxy_port)
+            gw._python_proxy_port = gw._callback_server.port  # update port with real port

     def _initialize_context(self, sc, duration):
         return self._jvm.JavaStreamingContext(sc._jsc, duration._jduration)
@@ -92,26 +129,44 @@ def stop(self, stopSparkContext=True, stopGraceFully=False):

     def remember(self, duration):
         """
-        Set each DStreams in this context to remember RDDs it generated in the last given duration.
-        DStreams remember RDDs only for a limited duration of time and releases them for garbage
-        collection. This method allows the developer to specify how to long to remember the RDDs (
-        if the developer wishes to query old data outside the DStream computation).
-        @param duration pyspark.streaming.duration.Duration object or seconds.
-                        Minimum duration that each DStream should remember its RDDs
+        Set each DStreams in this context to remember RDDs it generated
+        in the last given duration. DStreams remember RDDs only for a
+        limited duration of time and releases them for garbage collection.
+        This method allows the developer to specify how to long to remember
+        the RDDs ( if the developer wishes to query old data outside the
+        DStream computation).
+
+        @param duration Minimum duration (in seconds) that each DStream
+                        should remember its RDDs
         """
         if isinstance(duration, (int, long, float)):
             duration = Seconds(duration)

         self._jssc.remember(duration._jduration)

-    # TODO: add storageLevel
-    def socketTextStream(self, hostname, port):
+    def checkpoint(self, directory):
+        """
+        Sets the context to periodically checkpoint the DStream operations for master
+        fault-tolerance. The graph will be checkpointed every batch interval.
+
+        @param directory HDFS-compatible directory where the checkpoint data
+                         will be reliably stored
+        """
+        self._jssc.checkpoint(directory)
+
+    def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2):
         """
         Create an input from TCP source hostname:port. Data is received using
         a TCP socket and receive byte is interpreted as UTF8 encoded '\n' delimited
         lines.
+
+        @param hostname      Hostname to connect to for receiving data
+        @param port          Port to connect to for receiving data
+        @param storageLevel  Storage level to use for storing the received objects
         """
-        return DStream(self._jssc.socketTextStream(hostname, port), self, UTF8Deserializer())
+        jlevel = self._sc._getJavaStorageLevel(storageLevel)
+        return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self,
+                       UTF8Deserializer())

     def textFileStream(self, directory):
         """
@@ -122,14 +177,52 @@ def textFileStream(self, directory):
         """
         return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())

-    def _makeStream(self, inputs, numSlices=None):
+    def _check_serialzers(self, rdds):
+        # make sure they have same serializer
+        if len(set(rdd._jrdd_deserializer for rdd in rdds)):
+            for i in range(len(rdds)):
+                # reset them to sc.serializer
+                rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True)
+
+    def queueStream(self, queue, oneAtATime=False, default=None):
         """
-        This function is only for unittest.
-        It requires a list as input, and returns the i_th element at the i_th batch
-        under manual clock.
+        Create an input stream from an queue of RDDs or list. In each batch,
+        it will process either one or all of the RDDs returned by the queue.
+
+        NOTE: changes to the queue after the stream is created will not be recognized.
+        @param queue Queue of RDDs
+        @tparam T Type of objects in the RDD
         """
-        rdds = [self._sc.parallelize(input, numSlices) for input in inputs]
+        if queue and not isinstance(queue[0], RDD):
+            rdds = [self._sc.parallelize(input) for input in queue]
+        else:
+            rdds = queue
+        self._check_serialzers(rdds)
         jrdds = ListConverter().convert([r._jrdd for r in rdds],
                                         SparkContext._gateway._gateway_client)
-        jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds).asJavaDStream()
-        return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
+        jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds, oneAtATime,
+                                                   default and default._jrdd)
+        return DStream(jdstream.asJavaDStream(), self, rdds[0]._jrdd_deserializer)
+
+    def transform(self, dstreams, transformFunc):
+        """
+        Create a new DStream in which each RDD is generated by applying a function on RDDs of
+        the DStreams. The order of the JavaRDDs in the transform function parameter will be the
+        same as the order of corresponding DStreams in the list.
+        """
+        # TODO
+
+    def union(self, *dstreams):
+        """
+        Create a unified DStream from multiple DStreams of the same
+        type and same slide duration.
+        """
+        if not dstreams:
+            raise ValueError("should have at least one DStream to union")
+        if len(dstreams) == 1:
+            return dstreams[0]
+        self._check_serialzers(dstreams)
+        first = dstreams[0]
+        jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
+                                        SparkContext._gateway._gateway_client)
+        return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer)
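
For illustration, a minimal sketch of driving the new `queueStream` and `union` methods; `ssc` and `sc` are assumed to be an existing StreamingContext and its SparkContext, and each inner list becomes one batch:

    batches = [[1, 2, 3], [4, 5], [6]]            # one list per batch
    d1 = ssc.queueStream(batches)                 # plain lists are parallelized into RDDs
    d2 = ssc.queueStream([sc.parallelize(b) for b in batches])   # or pass RDDs directly
    unified = ssc.union(d1, d2)                   # DStreams must share type and slide duration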

python/pyspark/streaming/dstream.py

Lines changed: 4 additions & 4 deletions

@@ -315,16 +315,16 @@ def repartitions(self, numPartitions):
         return self.transform(lambda rdd: rdd.repartition(numPartitions))

     def union(self, other):
-        return self.transformWith(lambda a, b: a.union(b), other, True)
+        return self.transformWith(lambda a, b, t: a.union(b), other, True)

     def cogroup(self, other):
-        return self.transformWith(lambda a, b: a.cogroup(b), other)
+        return self.transformWith(lambda a, b, t: a.cogroup(b), other)

     def leftOuterJoin(self, other):
-        return self.transformWith(lambda a, b: a.leftOuterJion(b), other)
+        return self.transformWith(lambda a, b, t: a.leftOuterJion(b), other)

     def rightOuterJoin(self, other):
-        return self.transformWith(lambda a, b: a.rightOuterJoin(b), other)
+        return self.transformWith(lambda a, b, t: a.rightOuterJoin(b), other)

     def _jtime(self, milliseconds):
         return self.ctx._jvm.Time(milliseconds)
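
The lambdas above gain a third parameter because `transformWith` now passes the batch time to the callback along with the two RDDs. A hedged sketch of a user-supplied function under this signature (`stream_a` and `stream_b` are assumed DStreams):

    def join_by_batch(rdd_a, rdd_b, time):
        # the third argument is the batch time; this sketch simply ignores it
        return rdd_a.join(rdd_b)

    joined = stream_a.transformWith(join_by_batch, stream_b)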

python/pyspark/streaming/tests.py

Lines changed: 38 additions & 24 deletions

@@ -40,27 +40,25 @@ def setUp(self):
         class_name = self.__class__.__name__
         self.sc = SparkContext(appName=class_name)
         self.sc.setCheckpointDir("/tmp")
+        # TODO: decrease duration to speed up tests
         self.ssc = StreamingContext(self.sc, duration=Seconds(1))

     def tearDown(self):
         self.ssc.stop()
-        self.sc.stop()

     @classmethod
     def tearDownClass(cls):
         # Make sure tp shutdown the callback server
         SparkContext._gateway._shutdown_callback_server()

-    def _test_func(self, input, func, expected, numSlices=None, sort=False):
+    def _test_func(self, input, func, expected, sort=False):
         """
-        Start stream and return the result.
         @param input: dataset for the test. This should be list of lists.
         @param func: wrapped function. This function should return PythonDStream object.
         @param expected: expected output for this testcase.
-        @param numSlices: the number of slices in the rdd in the dstream.
         """
         # Generate input stream with user-defined input.
-        input_stream = self.ssc._makeStream(input, numSlices)
+        input_stream = self.ssc.queueStream(input)
         # Apply test function to stream.
         stream = func(input_stream)
         result = stream.collect()
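
To make the updated test flow concrete, a small sketch of what `_test_func` now does under the hood (names `ssc` and `func` are assumed, as in the tests):

    input_stream = ssc.queueStream([[1, 2, 3], [4, 5]])   # replaces the removed _makeStream
    stream = func(input_stream)                           # e.g. lambda d: d.count()
    result = stream.collect()                             # results are gathered once ssc starts
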
@@ -121,7 +119,7 @@ def func(dstream):

     def test_count(self):
         """Basic operation test for DStream.count."""
-        input = [range(1, 5), range(1, 10), range(1, 20)]
+        input = [range(5), range(10), range(20)]

         def func(dstream):
             return dstream.count()
@@ -178,24 +176,24 @@ def func(dstream):
     def test_glom(self):
         """Basic operation test for DStream.glom."""
         input = [range(1, 5), range(5, 9), range(9, 13)]
-        numSlices = 2
+        rdds = [self.sc.parallelize(r, 2) for r in input]

         def func(dstream):
             return dstream.glom()
         expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
-        self._test_func(input, func, expected, numSlices)
+        self._test_func(rdds, func, expected)

     def test_mapPartitions(self):
         """Basic operation test for DStream.mapPartitions."""
         input = [range(1, 5), range(5, 9), range(9, 13)]
-        numSlices = 2
+        rdds = [self.sc.parallelize(r, 2) for r in input]

         def func(dstream):
             def f(iterator):
                 yield sum(iterator)
             return dstream.mapPartitions(f)
         expected = [[3, 7], [11, 15], [19, 23]]
-        self._test_func(input, func, expected, numSlices)
+        self._test_func(rdds, func, expected)

     def test_countByValue(self):
         """Basic operation test for DStream.countByValue."""
@@ -236,14 +234,14 @@ def add(a, b):
         self._test_func(input, func, expected, sort=True)

     def test_union(self):
-        input1 = [range(3), range(5), range(1)]
+        input1 = [range(3), range(5), range(1), range(6)]
         input2 = [range(3, 6), range(5, 6), range(1, 6)]

-        d1 = self.ssc._makeStream(input1)
-        d2 = self.ssc._makeStream(input2)
+        d1 = self.ssc.queueStream(input1)
+        d2 = self.ssc.queueStream(input2)
         d = d1.union(d2)
         result = d.collect()
-        expected = [range(6), range(6), range(6)]
+        expected = [range(6), range(6), range(6), range(6)]

         self.ssc.start()
         start_time = time.time()
@@ -317,33 +315,49 @@ def func(dstream):
 class TestStreamingContext(unittest.TestCase):
     def setUp(self):
         self.sc = SparkContext(master="local[2]", appName=self.__class__.__name__)
-        self.batachDuration = Seconds(1)
-        self.ssc = None
+        self.batachDuration = Seconds(0.1)
+        self.ssc = StreamingContext(self.sc, self.batachDuration)

     def tearDown(self):
-        if self.ssc is not None:
-            self.ssc.stop()
+        self.ssc.stop()
         self.sc.stop()

     def test_stop_only_streaming_context(self):
-        self.ssc = StreamingContext(self.sc, self.batachDuration)
-        self._addInputStream(self.ssc)
+        self._addInputStream()
         self.ssc.start()
         self.ssc.stop(False)
         self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)

     def test_stop_multiple_times(self):
-        self.ssc = StreamingContext(self.sc, self.batachDuration)
-        self._addInputStream(self.ssc)
+        self._addInputStream()
         self.ssc.start()
         self.ssc.stop()
         self.ssc.stop()

-    def _addInputStream(self, s):
+    def _addInputStream(self):
         # Make sure each length of input is over 3
         inputs = map(lambda x: range(1, x), range(5, 101))
-        stream = s._makeStream(inputs)
+        stream = self.ssc.queueStream(inputs)
         stream.collect()

+    def test_queueStream(self):
+        input = [range(i) for i in range(3)]
+        dstream = self.ssc.queueStream(input)
+        result = dstream.collect()
+        self.ssc.start()
+        time.sleep(1)
+        self.assertEqual(input, result)
+
+    def test_union(self):
+        input = [range(i) for i in range(3)]
+        dstream = self.ssc.queueStream(input)
+        dstream2 = self.ssc.union(dstream, dstream)
+        result = dstream.collect()
+        self.ssc.start()
+        time.sleep(1)
+        expected = [i * 2 for i in input]
+        self.assertEqual(input, result)
+
+
 if __name__ == "__main__":
     unittest.main()

python/pyspark/streaming/util.py

Lines changed: 10 additions & 3 deletions

@@ -30,7 +30,10 @@ def __init__(self, ctx, func, jrdd_deserializer):

     def call(self, jrdd, milliseconds):
         try:
-            rdd = RDD(jrdd, self.ctx, self.deserializer)
+            emptyRDD = getattr(self.ctx, "_emptyRDD", None)
+            if emptyRDD is None:
+                self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()
+            rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else emptyRDD
             r = self.func(rdd, milliseconds)
             if r:
                 return r._jrdd
@@ -58,8 +61,12 @@ def __init__(self, ctx, func, jrdd_deserializer, jrdd_deserializer2=None):

     def call(self, jrdd, jrdd2, milliseconds):
         try:
-            rdd = RDD(jrdd, self.ctx, self.jrdd_deserializer) if jrdd else None
-            other = RDD(jrdd2, self.ctx, self.jrdd_deserializer2) if jrdd2 else None
+            emptyRDD = getattr(self.ctx, "_emptyRDD", None)
+            if emptyRDD is None:
+                self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()
+
+            rdd = RDD(jrdd, self.ctx, self.jrdd_deserializer) if jrdd else emptyRDD
+            other = RDD(jrdd2, self.ctx, self.jrdd_deserializer2) if jrdd2 else emptyRDD
             r = self.func(rdd, other, milliseconds)
             if r:
                 return r._jrdd
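
The cached `_emptyRDD` replaces the earlier `None` sentinel so that user transform functions can treat both inputs uniformly; a hedged illustration of why that helps:

    # With the change above, a binary transform function needs no None checks:
    def func(rdd, other, t):
        return rdd.union(other)    # safe even when `other` carried no data in this batch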
