# limitations under the License.
#

- from pyspark.serializers import UTF8Deserializer
+ from pyspark import RDD
+ from pyspark.serializers import UTF8Deserializer, BatchedSerializer
from pyspark.context import SparkContext
+ from pyspark.storagelevel import StorageLevel
from pyspark.streaming.dstream import DStream
- from pyspark.streaming.duration import Duration, Seconds
+ from pyspark.streaming.duration import Seconds

from py4j.java_collections import ListConverter

__all__ = ["StreamingContext"]


+ def _daemonize_callback_server():
+     """
+     Hack Py4J to daemonize callback server
+     """
+     # TODO: create a patch for Py4J
+     import socket
+     import py4j.java_gateway
+     logger = py4j.java_gateway.logger
+     from py4j.java_gateway import Py4JNetworkError
+     from threading import Thread
+
+     def start(self):
+         """Starts the CallbackServer. This method should be called by the
+         client instead of run()."""
+         self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+         self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
+                                       1)
+         try:
+             self.server_socket.bind((self.address, self.port))
+             # self.port = self.server_socket.getsockname()[1]
+         except Exception:
+             msg = 'An error occurred while trying to start the callback server'
+             logger.exception(msg)
+             raise Py4JNetworkError(msg)
+
+         # Maybe the thread needs to be cleaned up?
+         self.thread = Thread(target=self.run)
+         self.thread.daemon = True
+         self.thread.start()
+
+     py4j.java_gateway.CallbackServer.start = start
+
+
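The only functional change relative to Py4J's stock CallbackServer.start is the daemon flag on the server thread: a non-daemon callback-server thread would keep the Python driver process alive after the user script returns. A minimal, stand-alone sketch of that behavior (illustrative only, not part of this file):

    from threading import Thread
    import time

    t = Thread(target=lambda: time.sleep(3600))
    t.daemon = True  # without this, interpreter exit would block on the sleeping thread
    t.start()
    # the process can exit right away because the worker is a daemon thread
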
class StreamingContext(object):
    """
    Main entry point for Spark Streaming functionality. A StreamingContext represents the
@@ -53,7 +88,9 @@ def _start_callback_server(self):
        gw = self._sc._gateway
        # getattr will fallback to JVM
        if "_callback_server" not in gw.__dict__:
+             _daemonize_callback_server()
            gw._start_callback_server(gw._python_proxy_port)
+             gw._python_proxy_port = gw._callback_server.port  # update port with real port

    def _initialize_context(self, sc, duration):
        return self._jvm.JavaStreamingContext(sc._jsc, duration._jduration)
@@ -92,26 +129,44 @@ def stop(self, stopSparkContext=True, stopGraceFully=False):

    def remember(self, duration):
        """
-         Set each DStreams in this context to remember RDDs it generated in the last given duration.
-         DStreams remember RDDs only for a limited duration of time and releases them for garbage
-         collection. This method allows the developer to specify how to long to remember the RDDs (
-         if the developer wishes to query old data outside the DStream computation).
-         @param duration pyspark.streaming.duration.Duration object or seconds.
-                         Minimum duration that each DStream should remember its RDDs
+         Set each DStream in this context to remember the RDDs it generated
+         in the last given duration. DStreams remember RDDs only for a
+         limited duration of time and release them for garbage collection.
+         This method allows the developer to specify how long to remember
+         the RDDs (if the developer wishes to query old data outside the
+         DStream computation).
+
+         @param duration Minimum duration (in seconds) that each DStream
+                         should remember its RDDs
        """
        if isinstance(duration, (int, long, float)):
            duration = Seconds(duration)

        self._jssc.remember(duration._jduration)

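A hedged usage sketch of remember (assumes an existing StreamingContext named ssc; not part of this diff):

    # keep generated RDDs for at least 60 seconds so older data can still be queried
    ssc.remember(60)
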
-     # TODO: add storageLevel
-     def socketTextStream(self, hostname, port):
+     def checkpoint(self, directory):
+         """
+         Sets the context to periodically checkpoint the DStream operations for master
+         fault-tolerance. The graph will be checkpointed every batch interval.
+
+         @param directory HDFS-compatible directory where the checkpoint data
+                          will be reliably stored
+         """
+         self._jssc.checkpoint(directory)
+
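A hedged usage sketch of checkpoint (the directory URI is illustrative; any HDFS-compatible path works):

    # checkpoint the DStream graph every batch interval for driver fault-tolerance
    ssc.checkpoint("hdfs:///tmp/streaming-checkpoint")
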
+     def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2):
        """
        Create an input from TCP source hostname:port. Data is received using
        a TCP socket and the received bytes are interpreted as UTF8 encoded '\n' delimited
        lines.
+
+         @param hostname      Hostname to connect to for receiving data
+         @param port          Port to connect to for receiving data
+         @param storageLevel  Storage level to use for storing the received objects
        """
-         return DStream(self._jssc.socketTextStream(hostname, port), self, UTF8Deserializer())
+         jlevel = self._sc._getJavaStorageLevel(storageLevel)
+         return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self,
+                        UTF8Deserializer())
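A hedged usage sketch of socketTextStream (host and port are placeholders; the storage level argument is optional and defaults to MEMORY_AND_DISK_SER_2):

    # receive '\n'-delimited UTF-8 text from a TCP socket with an overridden storage level
    lines = ssc.socketTextStream("localhost", 9999,
                                 storageLevel=StorageLevel.MEMORY_ONLY)
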

    def textFileStream(self, directory):
        """
@@ -122,14 +177,52 @@ def textFileStream(self, directory):
        """
        return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())

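A hedged usage sketch of textFileStream (the path is illustrative; new files moved into the monitored directory are read as text):

    logs = ssc.textFileStream("hdfs:///data/incoming")
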
-     def _makeStream(self, inputs, numSlices=None):
+     def _check_serialzers(self, rdds):
+         # make sure they all use the same serializer
+         if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
+             for i in range(len(rdds)):
+                 # reset them to sc.serializer
+                 rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True)
+
+     def queueStream(self, queue, oneAtATime=False, default=None):
        """
-         This function is only for unittest.
-         It requires a list as input, and returns the i_th element at the i_th batch
-         under manual clock.
+         Create an input stream from a queue of RDDs or lists. In each batch,
+         it will process either one or all of the RDDs returned by the queue.
+
+         NOTE: changes to the queue after the stream is created will not be recognized.
+         @param queue   Queue of RDDs
+         @tparam T      Type of objects in the RDD
        """
-         rdds = [self._sc.parallelize(input, numSlices) for input in inputs]
+         if queue and not isinstance(queue[0], RDD):
+             rdds = [self._sc.parallelize(input) for input in queue]
+         else:
+             rdds = queue
+         self._check_serialzers(rdds)
        jrdds = ListConverter().convert([r._jrdd for r in rdds],
                                        SparkContext._gateway._gateway_client)
-         jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds).asJavaDStream()
-         return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
+         jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds, oneAtATime,
+                                                    default and default._jrdd)
+         return DStream(jdstream.asJavaDStream(), self, rdds[0]._jrdd_deserializer)
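A hedged usage sketch of queueStream (the data is illustrative; plain lists are parallelized into RDDs automatically):

    # each list becomes one RDD; with oneAtATime=True a single RDD is consumed per batch
    stream = ssc.queueStream([[1, 2, 3], [4, 5, 6]], oneAtATime=True)
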
+
+     def transform(self, dstreams, transformFunc):
+         """
+         Create a new DStream in which each RDD is generated by applying a function on RDDs of
+         the DStreams. The order of the JavaRDDs in the transform function parameter will be the
+         same as the order of corresponding DStreams in the list.
+         """
+         # TODO
+
+     def union(self, *dstreams):
+         """
+         Create a unified DStream from multiple DStreams of the same
+         type and same slide duration.
+         """
+         if not dstreams:
+             raise ValueError("should have at least one DStream to union")
+         if len(dstreams) == 1:
+             return dstreams[0]
+         self._check_serialzers(dstreams)
+         first = dstreams[0]
+         jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
+                                         SparkContext._gateway._gateway_client)
+         return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer)
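A hedged usage sketch of union (stream1 and stream2 stand for DStreams of the same type and slide duration):

    # merge two DStreams into one that emits the elements of both in each batch
    merged = ssc.union(stream1, stream2)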