Commit a5a8f9f

added Python test
1 parent 9767d82 commit a5a8f9f

7 files changed: +216 -114 lines changed

dev/run-tests.py

Lines changed: 2 additions & 1 deletion
@@ -294,7 +294,8 @@ def build_spark_sbt(hadoop_version):
     sbt_goals = ["package",
                  "assembly/assembly",
                  "streaming-kafka-assembly/assembly",
-                 "streaming-flume-assembly/assembly"]
+                 "streaming-flume-assembly/assembly",
+                 "streaming-mqtt-assembly/assembly"]
     profiles_and_goals = build_profiles + sbt_goals
 
     print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ",

dev/sparktestsupport/modules.py

Lines changed: 2 additions & 1 deletion
@@ -170,6 +170,7 @@ def contains_file(self, filename):
     dependencies=[streaming],
     source_file_regexes=[
         "external/mqtt",
+        "external/mqtt-assembly",
     ],
     sbt_test_goals=[
         "streaming-mqtt/test",
@@ -290,7 +291,7 @@ def contains_file(self, filename):
 
 pyspark_streaming = Module(
     name="pyspark-streaming",
-    dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly],
+    dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly, streaming_mqtt],
     source_file_regexes=[
         "python/pyspark/streaming"
     ],

docs/streaming-programming-guide.md

Lines changed: 1 addition & 1 deletion
@@ -683,7 +683,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea
 {:.no_toc}
 
 <span class="badge" style="background-color: grey">Python API</span> As of Spark {{site.SPARK_VERSION_SHORT}},
-out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future.
+out of these sources, *only* Kafka, Flume and MQTT are available in the Python API. We will add more advanced sources in the Python API in future.
 
 This category of sources require interfacing with external non-Spark libraries, some of them with
 complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts
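With this change an MQTT input stream can be created from PySpark much like the Kafka and Flume ones. A minimal word-count sketch, assuming the streaming-mqtt assembly jar is on the classpath and using the pyspark.streaming.mqtt.MQTTUtils wrapper added by this pull request; the broker URL and topic below are placeholders, not values from the commit:

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext
    from pyspark.streaming.mqtt import MQTTUtils

    sc = SparkContext(appName="MQTTWordCount")
    ssc = StreamingContext(sc, 1)

    # Placeholder broker URL and topic; any MQTT broker reachable from the driver works.
    lines = MQTTUtils.createStream(ssc, "tcp://localhost:1883", "foo")
    counts = lines.flatMap(lambda line: line.split(" ")) \
                  .map(lambda word: (word, 1)) \
                  .reduceByKey(lambda a, b: a + b)
    counts.pprint()

    ssc.start()
    ssc.awaitTermination()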

external/mqtt/pom.xml

Lines changed: 0 additions & 1 deletion
@@ -72,7 +72,6 @@
       <groupId>org.apache.activemq</groupId>
       <artifactId>activemq-core</artifactId>
       <version>5.7.0</version>
-      <scope>test</scope>
     </dependency>
   </dependencies>
   <build>
Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.mqtt
+
+import java.net.{ServerSocket, URI}
+import java.util.concurrent.{TimeUnit, CountDownLatch}
+
+import scala.language.postfixOps
+
+import org.apache.activemq.broker.{BrokerService, TransportConnector}
+import org.apache.commons.lang3.RandomUtils
+import org.eclipse.paho.client.mqttv3._
+import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
+
+import org.apache.spark.streaming.{StreamingContext, Milliseconds}
+import org.apache.spark.streaming.scheduler.StreamingListener
+import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
+import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SparkConf}
+
+/**
+ * Shared code for the Scala and Python unit tests.
+ */
+private class MQTTTestUtils extends Logging {
+
+  private val persistenceDir = Utils.createTempDir()
+  private val brokerHost = "localhost"
+  private var brokerPort = findFreePort()
+
+  private var broker: BrokerService = _
+  private var connector: TransportConnector = _
+
+  def brokerUri: String = {
+    s"$brokerHost:$brokerPort"
+  }
+
+  def setup(): Unit = {
+    broker = new BrokerService()
+    broker.setDataDirectoryFile(Utils.createTempDir())
+    connector = new TransportConnector()
+    connector.setName("mqtt")
+    connector.setUri(new URI("mqtt://" + brokerUri))
+    broker.addConnector(connector)
+    broker.start()
+  }
+
+  def teardown(): Unit = {
+    if (broker != null) {
+      broker.stop()
+      broker = null
+    }
+    if (connector != null) {
+      connector.stop()
+      connector = null
+    }
+  }
+
+  private def findFreePort(): Int = {
+    val candidatePort = RandomUtils.nextInt(1024, 65536)
+    Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
+      val socket = new ServerSocket(trialPort)
+      socket.close()
+      (null, trialPort)
+    }, new SparkConf())._2
+  }
+
+  def publishData(topic: String, data: String): Unit = {
+    var client: MqttClient = null
+    try {
+      val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
+      client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence)
+      client.connect()
+      if (client.isConnected) {
+        val msgTopic = client.getTopic(topic)
+        val message = new MqttMessage(data.getBytes("utf-8"))
+        message.setQos(1)
+        message.setRetained(true)
+
+        for (i <- 0 to 10) {
+          try {
+            msgTopic.publish(message)
+          } catch {
+            case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
+              // wait for Spark streaming to consume something from the message queue
+              Thread.sleep(50)
+          }
+        }
+      }
+    } finally {
+      client.disconnect()
+      client.close()
+      client = null
+    }
+  }
+
+  /**
+   * Block until at least one receiver has started or timeout occurs.
+   */
+  def waitForReceiverToStart(ssc: StreamingContext) = {
+    val latch = new CountDownLatch(1)
+    ssc.addStreamingListener(new StreamingListener {
+      override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
+        latch.countDown()
+      }
+    })
+
+    assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
+  }
+}
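The Python test itself (the seventh changed file, per the commit message "added Python test") is not expanded in this view. Purely as an illustration of how the shared helper above is meant to be driven from Python, here is a rough sketch, assuming the MQTTTestUtils class is reachable through the Py4J gateway; the _jvm access path, names, and timing are illustrative, not the committed test code:

    # Hypothetical sketch only: drive the JVM-side MQTTTestUtils from PySpark.
    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext
    from pyspark.streaming.mqtt import MQTTUtils

    sc = SparkContext(appName="mqtt-python-test-sketch")
    ssc = StreamingContext(sc, 0.5)

    # Assumed reachable via the Py4J gateway once the MQTT assembly jar is on the classpath.
    test_utils = sc._jvm.org.apache.spark.streaming.mqtt.MQTTTestUtils()
    test_utils.setup()

    received = []
    stream = MQTTUtils.createStream(ssc, "tcp://" + test_utils.brokerUri(), "topic")
    stream.foreachRDD(lambda rdd: received.extend(rdd.collect()))

    ssc.start()
    test_utils.publishData("topic", "hello")
    # ... poll until "hello" shows up in `received` (or a timeout expires), then clean up.
    ssc.stop(stopSparkContext=True)
    test_utils.teardown()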

external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala

Lines changed: 21 additions & 109 deletions
@@ -17,61 +17,45 @@
 
 package org.apache.spark.streaming.mqtt
 
-import java.net.{URI, ServerSocket}
-import java.util.concurrent.CountDownLatch
-import java.util.concurrent.TimeUnit
-
 import scala.concurrent.duration._
 import scala.language.postfixOps
 
-import org.apache.activemq.broker.{TransportConnector, BrokerService}
-import org.apache.commons.lang3.RandomUtils
-import org.eclipse.paho.client.mqttv3._
-import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
-
-import org.scalatest.BeforeAndAfter
+import org.scalatest.BeforeAndAfterAll
 import org.scalatest.concurrent.Eventually
 
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.dstream.ReceiverInputDStream
-import org.apache.spark.streaming.scheduler.StreamingListener
-import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
 import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.util.Utils
-
-class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Milliseconds, StreamingContext}
 
-  private val batchDuration = Milliseconds(500)
-  private val master = "local[2]"
-  private val framework = this.getClass.getSimpleName
-  private val freePort = findFreePort()
-  private val brokerUri = "//localhost:" + freePort
-  private val topic = "def"
-  private val persistenceDir = Utils.createTempDir()
+class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll {
 
+  private val topic = "topic"
   private var ssc: StreamingContext = _
-  private var broker: BrokerService = _
-  private var connector: TransportConnector = _
+  private var MQTTTestUtils: MQTTTestUtils = _
 
-  before {
-    ssc = new StreamingContext(master, framework, batchDuration)
-    setupMQTT()
+  override def beforeAll(): Unit = {
+    MQTTTestUtils = new MQTTTestUtils
+    MQTTTestUtils.setup()
   }
 
-  after {
+  override def afterAll(): Unit = {
     if (ssc != null) {
       ssc.stop()
       ssc = null
     }
-    Utils.deleteRecursively(persistenceDir)
-    tearDownMQTT()
+
+    if (MQTTTestUtils != null) {
+      MQTTTestUtils.teardown()
+      MQTTTestUtils = null
+    }
   }
 
   test("mqtt input stream") {
+    val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
+    ssc = new StreamingContext(sparkConf, Milliseconds(500))
     val sendMessage = "MQTT demo for spark streaming"
     val receiveStream =
-      MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY)
+      MQTTUtils.createStream(ssc, "tcp://" + MQTTTestUtils.brokerUri, topic, StorageLevel.MEMORY_ONLY)
     @volatile var receiveMessage: List[String] = List()
     receiveStream.foreachRDD { rdd =>
       if (rdd.collect.length > 0) {
@@ -83,85 +67,13 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter
 
     // wait for the receiver to start before publishing data, or we risk failing
     // the test nondeterministically. See SPARK-4631
-    waitForReceiverToStart()
+    MQTTTestUtils.waitForReceiverToStart(ssc)
+
+    MQTTTestUtils.publishData(topic, sendMessage)
 
-    publishData(sendMessage)
     eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
       assert(sendMessage.equals(receiveMessage(0)))
     }
     ssc.stop()
   }
-
-  private def setupMQTT() {
-    broker = new BrokerService()
-    broker.setDataDirectoryFile(Utils.createTempDir())
-    connector = new TransportConnector()
-    connector.setName("mqtt")
-    connector.setUri(new URI("mqtt:" + brokerUri))
-    broker.addConnector(connector)
-    broker.start()
-  }
-
-  private def tearDownMQTT() {
-    if (broker != null) {
-      broker.stop()
-      broker = null
-    }
-    if (connector != null) {
-      connector.stop()
-      connector = null
-    }
-  }
-
-  private def findFreePort(): Int = {
-    val candidatePort = RandomUtils.nextInt(1024, 65536)
-    Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
-      val socket = new ServerSocket(trialPort)
-      socket.close()
-      (null, trialPort)
-    }, new SparkConf())._2
-  }
-
-  def publishData(data: String): Unit = {
-    var client: MqttClient = null
-    try {
-      val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
-      client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence)
-      client.connect()
-      if (client.isConnected) {
-        val msgTopic = client.getTopic(topic)
-        val message = new MqttMessage(data.getBytes("utf-8"))
-        message.setQos(1)
-        message.setRetained(true)
-
-        for (i <- 0 to 10) {
-          try {
-            msgTopic.publish(message)
-          } catch {
-            case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
-              // wait for Spark streaming to consume something from the message queue
-              Thread.sleep(50)
-          }
-        }
-      }
-    } finally {
-      client.disconnect()
-      client.close()
-      client = null
-    }
-  }
-
-  /**
-   * Block until at least one receiver has started or timeout occurs.
-   */
-  private def waitForReceiverToStart() = {
-    val latch = new CountDownLatch(1)
-    ssc.addStreamingListener(new StreamingListener {
-      override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
-        latch.countDown()
-      }
-    })
-
-    assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
-  }
 }
