Skip to content

Commit c9b6109

Browse files
committed
Simplify test + make access to akka frame size more modular
1 parent 281d7c9 commit c9b6109

File tree

3 files changed

+36
-29
lines changed

3 files changed

+36
-29
lines changed

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
3737

3838
private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf)
3939
extends Actor with Logging {
40-
val maxAkkaFrameSize = conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024 // MB
40+
val maxAkkaFrameSize = AkkaUtils.maxFrameSize(conf) * 1024 * 1024 // MB
4141

4242
def receive = {
4343
case GetMapOutputStatuses(shuffleId: Int) =>

core/src/main/scala/org/apache/spark/util/AkkaUtils.scala

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ private[spark] object AkkaUtils extends Logging {
4949

5050
val akkaTimeout = conf.getInt("spark.akka.timeout", 100)
5151

52-
val akkaFrameSize = conf.getInt("spark.akka.frameSize", 10)
52+
val akkaFrameSize = maxFrameSize(conf)
5353
val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false)
5454
val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off"
5555
if (!akkaLogLifecycleEvents) {
@@ -121,4 +121,9 @@ private[spark] object AkkaUtils extends Logging {
121121
def lookupTimeout(conf: SparkConf): FiniteDuration = {
122122
Duration.create(conf.get("spark.akka.lookupTimeout", "30").toLong, "seconds")
123123
}
124+
125+
/** Returns the default max frame size for Akka messages in MB. */
126+
def maxFrameSize(conf: SparkConf): Int = {
127+
conf.getInt("spark.akka.frameSize", 10)
128+
}
124129
}

core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala

Lines changed: 29 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ package org.apache.spark
2020
import scala.concurrent.Await
2121

2222
import akka.actor._
23+
import akka.testkit.TestActorRef
2324
import org.scalatest.FunSuite
2425

2526
import org.apache.spark.scheduler.MapStatus
@@ -100,7 +101,25 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
100101
}
101102

102103
test("remote fetch") {
103-
val (masterTracker, slaveTracker) = setUpMasterSlaveSystem(conf)
104+
val hostname = "localhost"
105+
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf,
106+
securityManager = new SecurityManager(conf))
107+
108+
// Will be cleared by LocalSparkContext
109+
System.setProperty("spark.driver.port", boundPort.toString)
110+
111+
val masterTracker = new MapOutputTrackerMaster(conf)
112+
masterTracker.trackerActor = actorSystem.actorOf(
113+
Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
114+
115+
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
116+
securityManager = new SecurityManager(conf))
117+
val slaveTracker = new MapOutputTracker(conf)
118+
val selection = slaveSystem.actorSelection(
119+
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
120+
val timeout = AkkaUtils.lookupTimeout(conf)
121+
slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
122+
104123
masterTracker.registerShuffle(10, 1)
105124
masterTracker.incrementEpoch()
106125
slaveTracker.updateEpoch(masterTracker.getEpoch)
@@ -113,7 +132,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
113132
masterTracker.incrementEpoch()
114133
slaveTracker.updateEpoch(masterTracker.getEpoch)
115134
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
116-
Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
135+
Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
117136

118137
masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
119138
masterTracker.incrementEpoch()
@@ -128,42 +147,25 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
128147
val newConf = new SparkConf
129148
newConf.set("spark.akka.frameSize", "1")
130149
newConf.set("spark.akka.askTimeout", "1") // Fail fast
131-
val (masterTracker, slaveTracker) = setUpMasterSlaveSystem(newConf)
150+
151+
val masterTracker = new MapOutputTrackerMaster(conf)
152+
val actorSystem = ActorSystem("test")
153+
val actorRef = TestActorRef[MapOutputTrackerMasterActor](
154+
new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem)
155+
val masterActor = actorRef.underlyingActor
132156

133157
// Frame size should be ~123B, and no exception should be thrown
134158
masterTracker.registerShuffle(10, 1)
135159
masterTracker.registerMapOutput(10, 0, new MapStatus(
136160
BlockManagerId("88", "mph", 1000, 0), Array.fill[Byte](10)(0)))
137-
slaveTracker.getServerStatuses(10, 0)
161+
masterActor.receive(GetMapOutputStatuses(10))
138162

139163
// Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception
140164
masterTracker.registerShuffle(20, 100)
141165
(0 until 100).foreach { i =>
142166
masterTracker.registerMapOutput(20, i, new MapStatus(
143167
BlockManagerId("999", "mps", 1000, 0), Array.fill[Byte](4000000)(0)))
144168
}
145-
intercept[SparkException] { slaveTracker.getServerStatuses(20, 0) }
146-
}
147-
148-
private def setUpMasterSlaveSystem(conf: SparkConf) = {
149-
val hostname = "localhost"
150-
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf,
151-
securityManager = new SecurityManager(conf))
152-
153-
// Will be cleared by LocalSparkContext
154-
System.setProperty("spark.driver.port", boundPort.toString)
155-
156-
val masterTracker = new MapOutputTrackerMaster(conf)
157-
masterTracker.trackerActor = actorSystem.actorOf(
158-
Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
159-
160-
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
161-
securityManager = new SecurityManager(conf))
162-
val slaveTracker = new MapOutputTracker(conf)
163-
val selection = slaveSystem.actorSelection(
164-
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
165-
val timeout = AkkaUtils.lookupTimeout(conf)
166-
slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
167-
(masterTracker, slaveTracker)
169+
intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) }
168170
}
169171
}

0 commit comments

Comments (0)