Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit cd1397d

Browse files
committed
Add a test for the propagation of a new rate limit from driver to receivers.
1 parent 6369b30 commit cd1397d

File tree

5 files changed

+56
-1
lines changed

5 files changed

+56
-1
lines changed

streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
271271
}
272272

273273
/** Get the attached executor. */
274-
private def executor = {
274+
private[streaming] def executor = {
275275
assert(executor_ != null, "Executor has not been attached to this receiver")
276276
executor_
277277
}

streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ private[streaming] abstract class ReceiverSupervisor(
5858
/** Time between a receiver is stopped and started again */
5959
private val defaultRestartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000)
6060

61+
/** The current maximum rate limit for this receiver. */
62+
private[streaming] def getCurrentRateLimit: Option[Int] = None
63+
6164
/** Exception associated with the stopping of the receiver */
6265
@volatile protected var stoppingError: Throwable = null
6366

streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ private[streaming] class ReceiverSupervisorImpl(
100100
}
101101
}, streamId, env.conf)
102102

103+
override private[streaming] def getCurrentRateLimit: Option[Int] =
104+
Some(blockGenerator.currentRateLimit.get)
105+
103106
/** Push a single record of received data into block generator. */
104107
def pushSingle(data: Any) {
105108
blockGenerator.addData(data)

streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,4 +537,19 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
537537
verifyOutput[W](output, expectedOutput, useSet)
538538
}
539539
}
540+
541+
/**
542+
* Wait until `cond` becomes true, or timeout ms have passed. This method checks the condition
543+
* every 100ms, so it won't wait more than 100ms more than necessary.
544+
*
545+
* @param cond A boolean that should become `true`
546+
* @param timemout How many millis to wait before giving up
547+
*/
548+
def waitUntil(cond: => Boolean, timeout: Int): Unit = {
549+
val start = System.currentTimeMillis()
550+
val end = start + timeout
551+
while ((System.currentTimeMillis() < end) && !cond) {
552+
Thread.sleep(100)
553+
}
554+
}
540555
}

streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ import org.apache.spark.SparkConf
2222
import org.apache.spark.storage.StorageLevel
2323
import org.apache.spark.streaming.receiver._
2424
import org.apache.spark.util.Utils
25+
import org.apache.spark.streaming.dstream.InputDStream
26+
import scala.reflect.ClassTag
27+
import org.apache.spark.streaming.dstream.ReceiverInputDStream
2528

2629
/** Testsuite for receiver scheduling */
2730
class ReceiverTrackerSuite extends TestSuiteBase {
@@ -72,15 +75,46 @@ class ReceiverTrackerSuite extends TestSuiteBase {
7275
assert(locations(0).length === 1)
7376
assert(locations(3).length === 1)
7477
}
78+
79+
test("Receiver tracker - propagates rate limit") {
80+
val newRateLimit = 100
81+
val ids = new TestReceiverInputDStream(ssc)
82+
val tracker = new ReceiverTracker(ssc)
83+
tracker.start()
84+
waitUntil(TestDummyReceiver.started, 5000)
85+
tracker.sendRateUpdate(ids.id, newRateLimit)
86+
// this is an async message, we need to wait a bit for it to be processed
87+
waitUntil(ids.getRateLimit.get == newRateLimit, 1000)
88+
assert(ids.getRateLimit.get === newRateLimit)
89+
}
90+
}
91+
92+
/** An input DStream with a hard-coded receiver that gives access to internals for testing. */
93+
private class TestReceiverInputDStream(@transient ssc_ : StreamingContext)
94+
extends ReceiverInputDStream[Int](ssc_) {
95+
96+
override def getReceiver(): DummyReceiver = TestDummyReceiver
97+
98+
def getRateLimit: Option[Int] =
99+
TestDummyReceiver.executor.getCurrentRateLimit
75100
}
76101

102+
/**
103+
* We need the receiver to be an object, otherwise serialization will create another one
104+
* and we won't be able to read its rate limit.
105+
*/
106+
private object TestDummyReceiver extends DummyReceiver
107+
77108
/**
78109
* Dummy receiver implementation
79110
*/
80111
private class DummyReceiver(host: Option[String] = None)
81112
extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
82113

114+
var started = false
115+
83116
def onStart() {
117+
started = true
84118
}
85119

86120
def onStop() {

0 commit comments

Comments
 (0)