
Commit 3b2e59c

Add EventLoop and change DAGScheduler to an EventLoop
1 parent: 1e42e96

File tree: 4 files changed (+254, -96 lines)

4 files changed

+254
-96
lines changed

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 43 additions & 69 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
 
 import java.io.NotSerializableException
 import java.util.Properties
+import java.util.concurrent.{TimeUnit, Executors}
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
@@ -28,8 +29,6 @@ import scala.language.postfixOps
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
-import akka.actor._
-import akka.actor.SupervisorStrategy.Stop
 import akka.pattern.ask
 import akka.util.Timeout
 
@@ -39,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage._
-import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils}
+import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils}
 import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
 
 /**
@@ -67,8 +66,6 @@ class DAGScheduler(
     clock: Clock = SystemClock)
   extends Logging {
 
-  import DAGScheduler._
-
   def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
     this(
       sc,
@@ -112,42 +109,31 @@ class DAGScheduler(
   // stray messages to detect.
   private val failedEpoch = new HashMap[String, Long]
 
-  private val dagSchedulerActorSupervisor =
-    env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))
-
   // A closure serializer that we reuse.
   // This is only safe because DAGScheduler runs in a single thread.
   private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
 
-  private[scheduler] var eventProcessActor: ActorRef = _
 
   /** If enabled, we may run certain actions like take() and first() locally. */
   private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)
 
   /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
   private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
 
-  private def initializeEventProcessActor() {
-    // blocking the thread until supervisor is started, which ensures eventProcessActor is
-    // not null before any job is submitted
-    implicit val timeout = Timeout(30 seconds)
-    val initEventActorReply =
-      dagSchedulerActorSupervisor ? Props(new DAGSchedulerEventProcessActor(this))
-    eventProcessActor = Await.result(initEventActorReply, timeout.duration).
-      asInstanceOf[ActorRef]
-  }
+  private val messageScheduler =
+    Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message"))
 
-  initializeEventProcessActor()
+  private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
   taskScheduler.setDAGScheduler(this)
 
   // Called by TaskScheduler to report task's starting.
   def taskStarted(task: Task[_], taskInfo: TaskInfo) {
-    eventProcessActor ! BeginEvent(task, taskInfo)
+    eventProcessLoop.post(BeginEvent(task, taskInfo))
   }
 
   // Called to report that a task has completed and results are being fetched remotely.
   def taskGettingResult(taskInfo: TaskInfo) {
-    eventProcessActor ! GettingResultEvent(taskInfo)
+    eventProcessLoop.post(GettingResultEvent(taskInfo))
   }
 
   // Called by TaskScheduler to report task completions or failures.
@@ -158,7 +144,8 @@ class DAGScheduler(
       accumUpdates: Map[Long, Any],
       taskInfo: TaskInfo,
       taskMetrics: TaskMetrics) {
-    eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
+    eventProcessLoop.post(
+      CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
   }
 
   /**
@@ -180,18 +167,18 @@ class DAGScheduler(
 
   // Called by TaskScheduler when an executor fails.
   def executorLost(execId: String) {
-    eventProcessActor ! ExecutorLost(execId)
+    eventProcessLoop.post(ExecutorLost(execId))
   }
 
   // Called by TaskScheduler when a host is added
   def executorAdded(execId: String, host: String) {
-    eventProcessActor ! ExecutorAdded(execId, host)
+    eventProcessLoop.post(ExecutorAdded(execId, host))
   }
 
   // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
   // cancellation of the job itself.
   def taskSetFailed(taskSet: TaskSet, reason: String) {
-    eventProcessActor ! TaskSetFailed(taskSet, reason)
+    eventProcessLoop.post(TaskSetFailed(taskSet, reason))
   }
 
   private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
@@ -496,8 +483,8 @@ class DAGScheduler(
     assert(partitions.size > 0)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
-    eventProcessActor ! JobSubmitted(
-      jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
+    eventProcessLoop.post(JobSubmitted(
+      jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties))
     waiter
   }
 
@@ -537,8 +524,8 @@ class DAGScheduler(
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val partitions = (0 until rdd.partitions.size).toArray
     val jobId = nextJobId.getAndIncrement()
-    eventProcessActor ! JobSubmitted(
-      jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
+    eventProcessLoop.post(JobSubmitted(
+      jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
     listener.awaitResult() // Will throw an exception if the job fails
   }
 
@@ -547,19 +534,19 @@ class DAGScheduler(
    */
   def cancelJob(jobId: Int) {
     logInfo("Asked to cancel job " + jobId)
-    eventProcessActor ! JobCancelled(jobId)
+    eventProcessLoop.post(JobCancelled(jobId))
   }
 
   def cancelJobGroup(groupId: String) {
     logInfo("Asked to cancel job group " + groupId)
-    eventProcessActor ! JobGroupCancelled(groupId)
+    eventProcessLoop.post(JobGroupCancelled(groupId))
   }
 
   /**
    * Cancel all jobs that are running or waiting in the queue.
    */
   def cancelAllJobs() {
-    eventProcessActor ! AllJobsCancelled
+    eventProcessLoop.post(AllJobsCancelled)
   }
 
   private[scheduler] def doCancelAllJobs() {
@@ -575,7 +562,7 @@ class DAGScheduler(
    * Cancel all jobs associated with a running or scheduled stage.
    */
   def cancelStage(stageId: Int) {
-    eventProcessActor ! StageCancelled(stageId)
+    eventProcessLoop.post(StageCancelled(stageId))
   }
 
   /**
@@ -1059,16 +1046,16 @@ class DAGScheduler(
 
         if (disallowStageRetryForTest) {
           abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
-        } else if (failedStages.isEmpty && eventProcessActor != null) {
+        } else if (failedStages.isEmpty && eventProcessLoop != null) {
           // Don't schedule an event to resubmit failed stages if failed isn't empty, because
-          // in that case the event will already have been scheduled. eventProcessActor may be
+          // in that case the event will already have been scheduled. eventProcessLoop may be
           // null during unit tests.
           // TODO: Cancel running tasks in the stage
-          import env.actorSystem.dispatcher
           logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
             s"$failedStage (${failedStage.name}) due to fetch failure")
-          env.actorSystem.scheduler.scheduleOnce(
-            RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages)
+          messageScheduler.schedule(new Runnable {
+            override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+          }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
         }
         failedStages += failedStage
         failedStages += mapStage
@@ -1326,40 +1313,21 @@ class DAGScheduler(
 
   def stop() {
     logInfo("Stopping DAGScheduler")
-    dagSchedulerActorSupervisor ! PoisonPill
+    eventProcessLoop.stop()
     taskScheduler.stop()
   }
-}
 
-private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler)
-  extends Actor with Logging {
-
-  override val supervisorStrategy =
-    OneForOneStrategy() {
-      case x: Exception =>
-        logError("eventProcesserActor failed; shutting down SparkContext", x)
-        try {
-          dagScheduler.doCancelAllJobs()
-        } catch {
-          case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
-        }
-        dagScheduler.sc.stop()
-        Stop
-    }
-
-  def receive = {
-    case p: Props => sender ! context.actorOf(p)
-    case _ => logWarning("received unknown message in DAGSchedulerActorSupervisor")
-  }
+  // Start the event thread at the end of the constructor
+  eventProcessLoop.start()
 }
 
-private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGScheduler)
-  extends Actor with Logging {
+private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler)
+  extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging {
 
   /**
    * The main event loop of the DAG scheduler.
    */
-  def receive = {
+  override def onReceive(event: DAGSchedulerEvent): Unit = event match {
     case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
       dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
         listener, properties)
@@ -1398,7 +1366,17 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule
       dagScheduler.resubmitFailedStages()
   }
 
-  override def postStop() {
+  override def onError(e: Throwable): Unit = {
+    logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e)
+    try {
+      dagScheduler.doCancelAllJobs()
+    } catch {
+      case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
+    }
+    dagScheduler.sc.stop()
+  }
+
+  override def onStop() {
     // Cancel any active jobs in postStop hook
     dagScheduler.cleanUpAfterSchedulerStop()
   }
@@ -1408,9 +1386,5 @@ private[spark] object DAGScheduler {
   // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
   // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
   // as more failure events come in
-  val RESUBMIT_TIMEOUT = 200.milliseconds
-
-  // The time, in millis, to wake up between polls of the completion queue in order to potentially
-  // resubmit failed stages
-  val POLL_TIMEOUT = 10L
+  val RESUBMIT_TIMEOUT = 200
 }
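Two mechanical translations run through the whole file: every `eventProcessActor ! Msg` becomes `eventProcessLoop.post(Msg)`, and Akka's `scheduler.scheduleOnce`, used for the delayed fetch-failure resubmit, becomes a JDK scheduled thread pool (the `messageScheduler` above). A minimal standalone sketch of that delayed-post pattern, assuming only `java.util.concurrent`; the object name and the println stand in for the real `ResubmitFailedStages` post and are not part of the commit:

import java.util.concurrent.{Executors, TimeUnit}

object DelayedPostSketch {
  def main(args: Array[String]): Unit = {
    // A single-threaded scheduler, like messageScheduler above
    val scheduler = Executors.newScheduledThreadPool(1)

    // Run once after 200 ms, mirroring DAGScheduler.RESUBMIT_TIMEOUT;
    // the real code posts ResubmitFailedStages to the event loop here
    scheduler.schedule(new Runnable {
      override def run(): Unit = println("post(ResubmitFailedStages)")
    }, 200, TimeUnit.MILLISECONDS)

    // By default, already-scheduled delayed tasks still run after shutdown()
    scheduler.shutdown()
    scheduler.awaitTermination(1, TimeUnit.SECONDS)
  }
}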
core/src/main/scala/org/apache/spark/util/EventLoop.scala (new file)

Lines changed: 110 additions & 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util

import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}

import scala.util.control.NonFatal

import org.apache.spark.Logging

/**
 * An event loop to receive events from the caller and process all events in the event thread. It
 * will start an exclusive event thread to process all events.
 */
abstract class EventLoop[E](name: String) extends Logging {

  private val eventQueue: BlockingQueue[E] = new LinkedBlockingDeque[E]()

  private val eventThread = new Thread(name) {
    setDaemon(true)

    override def run(): Unit = {
      try {
        while (true) {
          val event = eventQueue.take()
          try {
            onReceive(event)
          } catch {
            case NonFatal(e) => {
              try {
                onError(e)
              } catch {
                case NonFatal(e) => logError("Unexpected error in " + name, e)
              }
            }
          }
        }
      } catch {
        case ie: InterruptedException => // exit even if eventQueue is not empty
        case NonFatal(e) => logError("Unexpected error in " + name, e)
      }
    }

  }

  def start(): Unit = {
    // Call onStart before starting the event thread to make sure it happens before onReceive
    onStart()
    eventThread.start()
  }

  def stop(): Unit = {
    eventThread.interrupt()
    eventThread.join()
    // Call onStop after the event thread exits to make sure onReceive happens before onStop
    onStop()
  }

  /**
   * Put the event into the event queue. The event thread will process it later.
   */
  def post(event: E): Unit = {
    eventQueue.put(event)
  }

  /**
   * Return whether the event thread has already been started, but not yet stopped.
   */
  def isActive: Boolean = eventThread.isAlive

  /**
   * Invoked when `start()` is called, before the event thread starts.
   */
  def onStart(): Unit = {}

  /**
   * Invoked when `stop()` is called, after the event thread exits.
   */
  def onStop(): Unit = {}

  /**
   * Invoked in the event thread when polling events from the event queue.
   *
   * Note: Should avoid calling blocking actions in `onReceive`, or the event thread will be blocked
   * and cannot process events in time. If you want to call some blocking actions, run them in
   * another thread.
   */
  def onReceive(event: E): Unit

  /**
   * Invoked if `onReceive` throws any non-fatal error. `onError` must not itself throw any
   * non-fatal error.
   */
  def onError(e: Throwable): Unit

}
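To make the EventLoop contract concrete, here is a hypothetical subclass and driver. `GreeterEventLoop`, its event types, and the prints are illustrative only and not part of this commit; they follow the same pattern as `DAGSchedulerEventProcessLoop` above:

import org.apache.spark.util.EventLoop

sealed trait GreeterEvent
case class Greet(name: String) extends GreeterEvent
case object Flush extends GreeterEvent

// Hypothetical subclass: onReceive runs on the dedicated daemon event
// thread, never on the caller's thread.
class GreeterEventLoop extends EventLoop[GreeterEvent]("greeter-event-loop") {

  override def onReceive(event: GreeterEvent): Unit = event match {
    case Greet(name) => println(s"Hello, $name")
    case Flush => println("Flushed")
  }

  // Called if onReceive throws a non-fatal error; must not throw itself
  override def onError(e: Throwable): Unit = {
    println(s"Failed to process an event: ${e.getMessage}")
  }
}

object GreeterExample {
  def main(args: Array[String]): Unit = {
    val loop = new GreeterEventLoop
    loop.start()            // runs onStart(), then starts the event thread
    loop.post(Greet("DAG")) // enqueued; processed asynchronously
    loop.post(Flush)
    Thread.sleep(100)       // crude wait for the daemon thread to drain the queue
    loop.stop()             // interrupts and joins the thread, then runs onStop()
  }
}

Note that `stop()` interrupts the event thread, so events still queued at that point are dropped (the run loop exits "even if eventQueue is not empty"); the sleep above is only a demo-level way to let the queue drain. Because events are handled strictly in arrival order on a single thread, DAGScheduler keeps the single-threaded processing model it had under Akka, without the actor-system dependency.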
