Skip to content

Commit 028bde6

Browse files
committed
Further refactored receiver to allow restarting of a receiver.
1 parent 43f5290 commit 028bde6

File tree

8 files changed

+287
-100
lines changed

8 files changed

+287
-100
lines changed

external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,17 @@ class MQTTReceiver(
6969
storageLevel: StorageLevel
7070
) extends NetworkReceiver[String](storageLevel) {
7171

72-
def onStop() { }
72+
def onStop() {
73+
74+
}
7375

7476
def onStart() {
7577

7678
// Set up persistence for messages
77-
val peristance: MqttClientPersistence = new MemoryPersistence()
79+
val persistence = new MemoryPersistence()
7880

7981
// Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance
80-
val client: MqttClient = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance)
82+
val client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence)
8183

8284
// Connect to MqttBroker
8385
client.connect()
@@ -97,8 +99,7 @@ class MQTTReceiver(
9799
}
98100

99101
override def connectionLost(arg0: Throwable) {
100-
reportError("Connection lost ", arg0)
101-
stop()
102+
restart("Connection lost ", arg0)
102103
}
103104
}
104105

external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,11 @@ class TwitterReceiver(
7777
def onScrubGeo(l: Long, l1: Long) {}
7878
def onStallWarning(stallWarning: StallWarning) {}
7979
def onException(e: Exception) {
80-
reportError("Error receiving tweets", e)
81-
stop()
80+
restart("Error receiving tweets", e)
8281
}
8382
})
8483

85-
val query: FilterQuery = new FilterQuery
84+
val query = new FilterQuery
8685
if (filters.size > 0) {
8786
query.track(filters.toArray)
8887
twitterStream.filter(query)

streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.util.NextIterator
2424
import scala.reflect.ClassTag
2525

2626
import java.io._
27-
import java.net.Socket
27+
import java.net.{UnknownHostException, Socket}
2828
import org.apache.spark.Logging
2929
import org.apache.spark.streaming.receiver.NetworkReceiver
3030

@@ -51,19 +51,49 @@ class SocketReceiver[T: ClassTag](
5151
) extends NetworkReceiver[T](storageLevel) with Logging {
5252

5353
var socket: Socket = null
54+
var receivingThread: Thread = null
5455

5556
def onStart() {
56-
logInfo("Connecting to " + host + ":" + port)
57-
socket = new Socket(host, port)
58-
logInfo("Connected to " + host + ":" + port)
59-
val iterator = bytesToObjects(socket.getInputStream())
60-
while(!isStopped && iterator.hasNext) {
61-
store(iterator.next)
57+
receivingThread = new Thread("Socket Receiver") {
58+
override def run() {
59+
connect()
60+
receive()
61+
}
6262
}
63+
receivingThread.start()
6364
}
6465

6566
def onStop() {
66-
if (socket != null) socket.close()
67+
if (socket != null) {
68+
socket.close()
69+
}
70+
socket = null
71+
if (receivingThread != null) {
72+
receivingThread.join()
73+
}
74+
}
75+
76+
def connect() {
77+
try {
78+
logInfo("Connecting to " + host + ":" + port)
79+
socket = new Socket(host, port)
80+
} catch {
81+
case e: Exception =>
82+
restart("Could not connect to " + host + ":" + port, e)
83+
}
84+
}
85+
86+
def receive() {
87+
try {
88+
logInfo("Connected to " + host + ":" + port)
89+
val iterator = bytesToObjects(socket.getInputStream())
90+
while(!isStopped && iterator.hasNext) {
91+
store(iterator.next)
92+
}
93+
} catch {
94+
case e: Exception =>
95+
restart("Error receiving data from socket", e)
96+
}
6797
}
6898
}
6999

streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiver.scala

Lines changed: 73 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,14 @@ import org.apache.spark.storage.StorageLevel
3333
* class MyReceiver(storageLevel) extends NetworkReceiver[String](storageLevel) {
3434
* def onStart() {
3535
* // Setup stuff (start threads, open sockets, etc.) to start receiving data.
36-
* // Call store(...) to store received data into Spark's memory.
37-
* // Optionally, wait for other threads to complete or watch for exceptions.
38-
* // Call reportError(...) if there is an error that you cannot ignore and need
39-
* // the receiver to be terminated.
36+
* // Must start new thread to receive data, as onStart() must be non-blocking.
37+
*
38+
* // Call store(...) in those threads to store received data into Spark's memory.
39+
*
40+
* // Call stop(...), restart() or reportError(...) on any thread based on how
41+
* // different errors should be handled.
42+
*
43+
* // See corresponding method documentation for more details.
4044
* }
4145
*
4246
* def onStop() {
@@ -47,17 +51,24 @@ import org.apache.spark.storage.StorageLevel
4751
abstract class NetworkReceiver[T](val storageLevel: StorageLevel) extends Serializable {
4852

4953
/**
50-
* This method is called by the system when the receiver is started to start receiving data.
51-
* All threads and resources set up in this method must be cleaned up in onStop().
52-
* If there are exceptions on other threads such that the receiver must be terminated,
53-
* then you must call reportError(exception). However, the thread that called onStart() must
54-
* never catch and ignore InterruptedException (it can catch and rethrow).
54+
* This method is called by the system when the receiver is started. This function
55+
* must initialize all resources (threads, buffers, etc.) necessary for receiving data.
56+
* This function must be non-blocking, so receiving the data must occur on a different
57+
* thread. Received data can be stored with Spark by calling `store(data)`.
58+
*
59+
* If there are errors in threads started here, then following options can be done
60+
* (i) `reportError(...)` can be called to report the error to the driver.
61+
* The receiving of data will continue uninterrupted.
62+
* (ii) `stop(...)` can be called to stop receiving data. This will call `onStop()` to
63+
* clear up all resources allocated (threads, buffers, etc.) during `onStart()`.
64+
* (iii) `restart(...)` can be called to restart the receiver. This will call `onStop()`
65+
* immediately, and then `onStart()` after a delay.
5566
*/
5667
def onStart()
5768

5869
/**
59-
* This method is called by the system when the receiver is stopped to stop receiving data.
60-
* All threads and resources setup in onStart() must be cleaned up in this method.
70+
* This method is called by the system when the receiver is stopped. All resources
71+
* (threads, buffers, etc.) setup in `onStart()` must be cleaned up in this method.
6172
*/
6273
def onStop()
6374

@@ -95,6 +106,7 @@ abstract class NetworkReceiver[T](val storageLevel: StorageLevel) extends Serial
95106
def store(dataIterator: Iterator[T], metadata: Any) {
96107
executor.pushIterator(dataIterator, Some(metadata), None)
97108
}
109+
98110
/** Store the bytes of received data into Spark's memory. */
99111
def store(bytes: ByteBuffer) {
100112
executor.pushBytes(bytes, None, None)
@@ -107,24 +119,70 @@ abstract class NetworkReceiver[T](val storageLevel: StorageLevel) extends Serial
107119
def store(bytes: ByteBuffer, metadata: Any = null) {
108120
executor.pushBytes(bytes, Some(metadata), None)
109121
}
122+
110123
/** Report exceptions in receiving data. */
111124
def reportError(message: String, throwable: Throwable) {
112125
executor.reportError(message, throwable)
113126
}
114127

115-
/** Stop the receiver. */
116-
def stop() {
117-
executor.stop()
128+
/**
129+
* Restart the receiver. This will call `onStop()` immediately and return.
130+
* Asynchronously, after a delay, `onStart()` will be called.
131+
* The `message` will be reported to the driver.
132+
* The delay is defined by the Spark configuration
133+
* `spark.streaming.receiverRestartDelay`.
134+
*/
135+
def restart(message: String) {
136+
executor.restartReceiver(message)
137+
}
138+
139+
/**
140+
* Restart the receiver. This will call `onStop()` immediately and return.
141+
* Asynchronously, after a delay, `onStart()` will be called.
142+
* The `message` and `exception` will be reported to the driver.
143+
* The delay is defined by the Spark configuration
144+
* `spark.streaming.receiverRestartDelay`.
145+
*/
146+
def restart(message: String, exception: Throwable) {
147+
executor.restartReceiver(message, exception)
148+
}
149+
150+
/**
151+
* Restart the receiver. This will call `onStop()` immediately and return.
152+
* Asynchronously, after the given delay, `onStart()` will be called.
153+
*/
154+
def restart(message: String, throwable: Throwable, millisecond: Int) {
155+
executor.restartReceiver(message, throwable, millisecond)
156+
}
157+
158+
/** Stop the receiver completely. */
159+
def stop(message: String) {
160+
executor.stop(message)
161+
}
162+
163+
/** Stop the receiver completely due to an exception */
164+
def stop(message: String, exception: Throwable) {
165+
executor.stop(message, exception)
166+
}
167+
168+
def isStarted(): Boolean = {
169+
executor.isReceiverStarted()
118170
}
119171

120172
/** Check if receiver has been marked for stopping. */
121173
def isStopped(): Boolean = {
122-
executor.isStopped
174+
!executor.isReceiverStarted()
123175
}
124176

125177
/** Get unique identifier of this receiver. */
126178
def receiverId = id
127179

180+
/*
181+
* =================
182+
* Private methods
183+
* =================
184+
*/
185+
128186
/** Identifier of the stream this receiver is associated with. */
129187
private var id: Int = -1
130188

0 commit comments

Comments
 (0)