@@ -17,61 +17,45 @@
  */
 package org.apache.spark.streaming.mqtt
 
-import java.net.{URI, ServerSocket}
-import java.util.concurrent.CountDownLatch
-import java.util.concurrent.TimeUnit
-
 import scala.concurrent.duration._
 import scala.language.postfixOps
 
-import org.apache.activemq.broker.{TransportConnector, BrokerService}
-import org.apache.commons.lang3.RandomUtils
-import org.eclipse.paho.client.mqttv3._
-import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
-
-import org.scalatest.BeforeAndAfter
+import org.scalatest.BeforeAndAfterAll
 import org.scalatest.concurrent.Eventually
 
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.dstream.ReceiverInputDStream
-import org.apache.spark.streaming.scheduler.StreamingListener
-import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
 import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.util.Utils
-
-class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Milliseconds, StreamingContext}
 
-  private val batchDuration = Milliseconds(500)
-  private val master = "local[2]"
-  private val framework = this.getClass.getSimpleName
-  private val freePort = findFreePort()
-  private val brokerUri = "//localhost:" + freePort
-  private val topic = "def"
-  private val persistenceDir = Utils.createTempDir()
+class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll {
 
+  private val topic = "topic"
   private var ssc: StreamingContext = _
-  private var broker: BrokerService = _
-  private var connector: TransportConnector = _
+  private var MQTTTestUtils: MQTTTestUtils = _
 
-  before {
-    ssc = new StreamingContext(master, framework, batchDuration)
-    setupMQTT()
+  override def beforeAll(): Unit = {
+    MQTTTestUtils = new MQTTTestUtils
+    MQTTTestUtils.setup()
   }
 
-  after {
+  override def afterAll(): Unit = {
     if (ssc != null) {
       ssc.stop()
       ssc = null
     }
-    Utils.deleteRecursively(persistenceDir)
-    tearDownMQTT()
+
+    if (MQTTTestUtils != null) {
+      MQTTTestUtils.teardown()
+      MQTTTestUtils = null
+    }
   }
 
   test("mqtt input stream") {
+    val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
+    ssc = new StreamingContext(sparkConf, Milliseconds(500))
     val sendMessage = "MQTT demo for spark streaming"
     val receiveStream =
-      MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY)
+      MQTTUtils.createStream(ssc, "tcp://" + MQTTTestUtils.brokerUri, topic, StorageLevel.MEMORY_ONLY)
     @volatile var receiveMessage: List[String] = List()
     receiveStream.foreachRDD { rdd =>
       if (rdd.collect.length > 0) {
@@ -83,85 +67,13 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter
 
     // wait for the receiver to start before publishing data, or we risk failing
     // the test nondeterministically. See SPARK-4631
-    waitForReceiverToStart()
+    MQTTTestUtils.waitForReceiverToStart(ssc)
+
+    MQTTTestUtils.publishData(topic, sendMessage)
 
-    publishData(sendMessage)
     eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
       assert(sendMessage.equals(receiveMessage(0)))
     }
     ssc.stop()
   }
-
-  private def setupMQTT() {
-    broker = new BrokerService()
-    broker.setDataDirectoryFile(Utils.createTempDir())
-    connector = new TransportConnector()
-    connector.setName("mqtt")
-    connector.setUri(new URI("mqtt:" + brokerUri))
-    broker.addConnector(connector)
-    broker.start()
-  }
-
-  private def tearDownMQTT() {
-    if (broker != null) {
-      broker.stop()
-      broker = null
-    }
-    if (connector != null) {
-      connector.stop()
-      connector = null
-    }
-  }
-
-  private def findFreePort(): Int = {
-    val candidatePort = RandomUtils.nextInt(1024, 65536)
-    Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
-      val socket = new ServerSocket(trialPort)
-      socket.close()
-      (null, trialPort)
-    }, new SparkConf())._2
-  }
-
-  def publishData(data: String): Unit = {
-    var client: MqttClient = null
-    try {
-      val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
-      client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence)
-      client.connect()
-      if (client.isConnected) {
-        val msgTopic = client.getTopic(topic)
-        val message = new MqttMessage(data.getBytes("utf-8"))
-        message.setQos(1)
-        message.setRetained(true)
-
-        for (i <- 0 to 10) {
-          try {
-            msgTopic.publish(message)
-          } catch {
-            case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
-              // wait for Spark streaming to consume something from the message queue
-              Thread.sleep(50)
-          }
-        }
-      }
-    } finally {
-      client.disconnect()
-      client.close()
-      client = null
-    }
-  }
-
-  /**
-   * Block until at least one receiver has started or timeout occurs.
-   */
-  private def waitForReceiverToStart() = {
-    val latch = new CountDownLatch(1)
-    ssc.addStreamingListener(new StreamingListener {
-      override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
-        latch.countDown()
-      }
-    })
-
-    assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
-  }
 }
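
The new MQTTTestUtils class that this diff calls into is not part of the hunk. As orientation only, the sketch below reassembles the helpers the diff removes from the suite (setupMQTT, tearDownMQTT, publishData, waitForReceiverToStart) into a class exposing the interface the test now uses: setup(), teardown(), brokerUri, publishData(topic, data), and waitForReceiverToStart(ssc). The actual MQTTTestUtils in the Spark tree may differ in detail; in particular, the ServerSocket(0) port allocation here is a simplification I substituted for the original Utils.startServiceOnPort retry loop.

package org.apache.spark.streaming.mqtt

import java.net.{ServerSocket, URI}
import java.util.concurrent.{CountDownLatch, TimeUnit}

import org.apache.activemq.broker.{BrokerService, TransportConnector}
import org.eclipse.paho.client.mqttv3.{MqttClient, MqttException, MqttMessage}
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence

import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted}
import org.apache.spark.util.Utils

private[mqtt] class MQTTTestUtils {

  private val persistenceDir = Utils.createTempDir()
  private val brokerPort = findFreePort()

  private var broker: BrokerService = _
  private var connector: TransportConnector = _

  // No scheme here: the suite builds its client URI as "tcp://" + brokerUri.
  def brokerUri: String = s"localhost:$brokerPort"

  // Start an embedded ActiveMQ broker with an MQTT transport (from setupMQTT()).
  def setup(): Unit = {
    broker = new BrokerService()
    broker.setDataDirectoryFile(Utils.createTempDir())
    connector = new TransportConnector()
    connector.setName("mqtt")
    connector.setUri(new URI("mqtt://" + brokerUri))
    broker.addConnector(connector)
    broker.start()
  }

  // Stop the broker and clean up temporary state (from tearDownMQTT()).
  def teardown(): Unit = {
    if (broker != null) {
      broker.stop()
      broker = null
    }
    if (connector != null) {
      connector.stop()
      connector = null
    }
    Utils.deleteRecursively(persistenceDir)
  }

  // Simplified port allocation: bind port 0 and let the OS pick a free port.
  private def findFreePort(): Int = {
    val socket = new ServerSocket(0)
    try socket.getLocalPort finally socket.close()
  }

  // Publish a message to the given topic, retrying while the client's
  // in-flight window is full (from the suite's publishData(data)).
  def publishData(topic: String, data: String): Unit = {
    var client: MqttClient = null
    try {
      val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
      client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence)
      client.connect()
      if (client.isConnected) {
        val msgTopic = client.getTopic(topic)
        val message = new MqttMessage(data.getBytes("utf-8"))
        message.setQos(1)
        message.setRetained(true)
        for (i <- 0 to 10) {
          try {
            msgTopic.publish(message)
          } catch {
            case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
              // wait for Spark Streaming to consume something from the message queue
              Thread.sleep(50)
          }
        }
      }
    } finally {
      if (client != null) {
        client.disconnect()
        client.close()
        client = null
      }
    }
  }

  // Block until at least one receiver has started or the timeout expires
  // (from the suite, now taking the StreamingContext as a parameter).
  def waitForReceiverToStart(ssc: StreamingContext): Unit = {
    val latch = new CountDownLatch(1)
    ssc.addStreamingListener(new StreamingListener {
      override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = {
        latch.countDown()
      }
    })
    assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
  }
}

Pulling the broker bootstrap out of the suite this way lets other MQTT test suites share one helper instead of each spinning up its own ActiveMQ instance, which appears to be the point of the refactoring.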