
Commit 80e2568

squito authored and kayousterhout committed
[SPARK-8103][core] DAGScheduler should not submit multiple concurrent attempts for a stage
https://issues.apache.org/jira/browse/SPARK-8103

cc kayousterhout (thanks for the extra test case)

Author: Imran Rashid <[email protected]>
Author: Kay Ousterhout <[email protected]>
Author: Imran Rashid <[email protected]>

Closes apache#6750 from squito/SPARK-8103 and squashes the following commits:

fb3acfc [Imran Rashid] fix log msg
e01b7aa [Imran Rashid] fix some comments, style
584acd4 [Imran Rashid] simplify going from taskId to taskSetMgr
e43ac25 [Imran Rashid] Merge branch 'master' into SPARK-8103
6bc23af [Imran Rashid] update log msg
4470fa1 [Imran Rashid] rename
c04707e [Imran Rashid] style
88b61cc [Imran Rashid] add tests to make sure that TaskSchedulerImpl schedules correctly with zombie attempts
d7f1ef2 [Imran Rashid] get rid of activeTaskSets
a21c8b5 [Imran Rashid] Merge branch 'master' into SPARK-8103
906d626 [Imran Rashid] fix merge
109900e [Imran Rashid] Merge branch 'master' into SPARK-8103
c0d4d90 [Imran Rashid] Revert "Index active task sets by stage Id rather than by task set id"
f025154 [Imran Rashid] Merge pull request #2 from kayousterhout/imran_SPARK-8103
baf46e1 [Kay Ousterhout] Index active task sets by stage Id rather than by task set id
19685bb [Imran Rashid] switch to using latestInfo.attemptId, and add comments
a5f7c8c [Imran Rashid] remove comment for reviewers
227b40d [Imran Rashid] style
517b6e5 [Imran Rashid] get rid of SparkIllegalStateException
b2faef5 [Imran Rashid] faster check for conflicting task sets
6542b42 [Imran Rashid] remove extra stageAttemptId
ada7726 [Imran Rashid] reviewer feedback
d8eb202 [Imran Rashid] Merge branch 'master' into SPARK-8103
46bc26a [Imran Rashid] more cleanup of debug garbage
cb245da [Imran Rashid] finally found the issue ... clean up debug stuff
8c29707 [Imran Rashid] Merge branch 'master' into SPARK-8103
89a59b6 [Imran Rashid] more printlns ...
9601b47 [Imran Rashid] more debug printlns
ecb4e7d [Imran Rashid] debugging printlns
b6bc248 [Imran Rashid] style
55f4a94 [Imran Rashid] get rid of more random test case since kays tests are clearer
7021d28 [Imran Rashid] update test since listenerBus.waitUntilEmpty now throws an exception instead of returning a boolean
883fe49 [Kay Ousterhout] Unit tests for concurrent stages issue
6e14683 [Imran Rashid] unit test just to make sure we fail fast on concurrent attempts
06a0af6 [Imran Rashid] ignore for jenkins
c443def [Imran Rashid] better fix and simpler test case
28d70aa [Imran Rashid] wip on getting a better test case ...
a9bf31f [Imran Rashid] wip
1 parent c6fe9b4 commit 80e2568

13 files changed, +383 −86 lines changed

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 44 additions & 34 deletions
@@ -857,7 +857,6 @@ class DAGScheduler(
     // Get our pending tasks and remember them in our pendingTasks entry
     stage.pendingTasks.clear()
 
-
     // First figure out the indexes of partition ids to compute.
     val partitionsToCompute: Seq[Int] = {
       stage match {
@@ -918,7 +917,7 @@
          partitionsToCompute.map { id =>
            val locs = getPreferredLocs(stage.rdd, id)
            val part = stage.rdd.partitions(id)
-           new ShuffleMapTask(stage.id, taskBinary, part, locs)
+           new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
          }
 
        case stage: ResultStage =>
@@ -927,7 +926,7 @@
            val p: Int = job.partitions(id)
            val part = stage.rdd.partitions(p)
            val locs = getPreferredLocs(stage.rdd, p)
-           new ResultTask(stage.id, taskBinary, part, locs, id)
+           new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
          }
      }
    } catch {
@@ -1069,10 +1068,11 @@
            val execId = status.location.executorId
            logDebug("ShuffleMapTask finished on " + execId)
            if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
-             logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
+             logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
            } else {
              shuffleStage.addOutputLoc(smt.partitionId, status)
            }
+
            if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
              markStageAsFinished(shuffleStage)
              logInfo("looking for newly runnable stages")
@@ -1132,38 +1132,48 @@
        val failedStage = stageIdToStage(task.stageId)
        val mapStage = shuffleToMapStage(shuffleId)
 
-       // It is likely that we receive multiple FetchFailed for a single stage (because we have
-       // multiple tasks running concurrently on different executors). In that case, it is possible
-       // the fetch failure has already been handled by the scheduler.
-       if (runningStages.contains(failedStage)) {
-         logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
-           s"due to a fetch failure from $mapStage (${mapStage.name})")
-         markStageAsFinished(failedStage, Some(failureMessage))
-       }
+       if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
+         logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
+           s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
+           s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
+       } else {
 
-       if (disallowStageRetryForTest) {
-         abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
-       } else if (failedStages.isEmpty) {
-         // Don't schedule an event to resubmit failed stages if failed isn't empty, because
-         // in that case the event will already have been scheduled.
-         // TODO: Cancel running tasks in the stage
-         logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
-           s"$failedStage (${failedStage.name}) due to fetch failure")
-         messageScheduler.schedule(new Runnable {
-           override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
-         }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
-       }
-       failedStages += failedStage
-       failedStages += mapStage
-       // Mark the map whose fetch failed as broken in the map stage
-       if (mapId != -1) {
-         mapStage.removeOutputLoc(mapId, bmAddress)
-         mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
-       }
+         // It is likely that we receive multiple FetchFailed for a single stage (because we have
+         // multiple tasks running concurrently on different executors). In that case, it is
+         // possible the fetch failure has already been handled by the scheduler.
+         if (runningStages.contains(failedStage)) {
+           logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
+             s"due to a fetch failure from $mapStage (${mapStage.name})")
+           markStageAsFinished(failedStage, Some(failureMessage))
+         } else {
+           logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " +
+             s"longer running")
+         }
+
+         if (disallowStageRetryForTest) {
+           abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+         } else if (failedStages.isEmpty) {
+           // Don't schedule an event to resubmit failed stages if failed isn't empty, because
+           // in that case the event will already have been scheduled.
+           // TODO: Cancel running tasks in the stage
+           logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
+             s"$failedStage (${failedStage.name}) due to fetch failure")
+           messageScheduler.schedule(new Runnable {
+             override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+           }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+         }
+         failedStages += failedStage
+         failedStages += mapStage
+         // Mark the map whose fetch failed as broken in the map stage
+         if (mapId != -1) {
+           mapStage.removeOutputLoc(mapId, bmAddress)
+           mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+         }
 
-       // TODO: mark the executor as failed only if there were lots of fetch failures on it
-       if (bmAddress != null) {
-         handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+         // TODO: mark the executor as failed only if there were lots of fetch failures on it
+         if (bmAddress != null) {
+           handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+         }
       }
 
     case commitDenied: TaskCommitDenied =>
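
The behavioral core of this file's change is the new guard at the top of the FetchFailed handler: a fetch failure is acted on only when it comes from the stage's latest attempt. As a rough, self-contained illustration of that gate (AttemptGateDemo, Stage, StageInfo, and FetchFailed below are invented stand-ins, not the real Spark classes; only the latestInfo.attemptId / stageAttemptId comparison mirrors the diff):

// Standalone sketch of the attempt-id gate (not Spark code): a stage records
// the id of its latest attempt, and failure events carry the attempt id that
// produced them, so events from superseded attempts can be dropped.
object AttemptGateDemo {
  case class StageInfo(attemptId: Int)
  class Stage { var latestInfo: StageInfo = StageInfo(0) }
  case class FetchFailed(stageAttemptId: Int)

  def handle(stage: Stage, event: FetchFailed): String =
    if (stage.latestInfo.attemptId != event.stageAttemptId) {
      // Stale: a newer attempt for this stage is already running.
      s"ignored (event from attempt ${event.stageAttemptId}, " +
        s"latest is ${stage.latestInfo.attemptId})"
    } else {
      "handled (mark the stage failed and schedule a resubmit)"
    }

  def main(args: Array[String]): Unit = {
    val stage = new Stage
    stage.latestInfo = StageInfo(attemptId = 1) // the stage was resubmitted once
    println(handle(stage, FetchFailed(stageAttemptId = 0))) // ignored
    println(handle(stage, FetchFailed(stageAttemptId = 1))) // handled
  }
}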

core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala

Lines changed: 2 additions & 1 deletion
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
  */
 private[spark] class ResultTask[T, U](
     stageId: Int,
+    stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient locs: Seq[TaskLocation],
     val outputId: Int)
-  extends Task[U](stageId, partition.index) with Serializable {
+  extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
 
   @transient private[this] val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq

core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala

Lines changed: 3 additions & 2 deletions
@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
  */
 private[spark] class ShuffleMapTask(
     stageId: Int,
+    stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId, partition.index) with Logging {
+  extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
 
   /** A constructor used only in test suites. This does not require passing in an RDD. */
   def this(partitionId: Int) {
-    this(0, null, new Partition { override def index: Int = 0 }, null)
+    this(0, 0, null, new Partition { override def index: Int = 0 }, null)
   }
 
   @transient private val preferredLocs: Seq[TaskLocation] = {

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 4 additions & 1 deletion
@@ -43,7 +43,10 @@ import org.apache.spark.util.Utils
  * @param stageId id of the stage this task belongs to
  * @param partitionId index of the number in the RDD
  */
-private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
+private[spark] abstract class Task[T](
+    val stageId: Int,
+    val stageAttemptId: Int,
+    var partitionId: Int) extends Serializable {
 
   /**
    * The key of the Map is the accumulator id and the value of the Map is the latest accumulator
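
Together with the ResultTask and ShuffleMapTask diffs above, this change simply threads the new stageAttemptId field through the task constructor chain. A compressed sketch of that shape (illustrative only; the real constructors also take the serialized task binary, partition, and preferred locations shown above):

// Simplified shape of the constructor plumbing (not the real Spark classes).
abstract class Task[T](val stageId: Int, val stageAttemptId: Int, var partitionId: Int)
  extends Serializable

class ShuffleMapTask(stageId: Int, stageAttemptId: Int, partitionId: Int)
  extends Task[Unit](stageId, stageAttemptId, partitionId)

class ResultTask[T, U](stageId: Int, stageAttemptId: Int, partitionId: Int, val outputId: Int)
  extends Task[U](stageId, stageAttemptId, partitionId)

object TaskPlumbingDemo {
  def main(args: Array[String]): Unit = {
    val task = new ShuffleMapTask(stageId = 2, stageAttemptId = 1, partitionId = 0)
    // Any component holding a task can now tell which stage attempt created it.
    println(s"stage=${task.stageId} attempt=${task.stageAttemptId} partition=${task.partitionId}")
  }
}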

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 64 additions & 35 deletions
@@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(
 
   // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val activeTaskSets = new HashMap[String, TaskSetManager]
+  private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
 
-  val taskIdToTaskSetId = new HashMap[Long, String]
+  private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
   val taskIdToExecutorId = new HashMap[Long, String]
 
   @volatile private var hasReceivedTask = false
@@ -162,7 +162,17 @@ private[spark] class TaskSchedulerImpl(
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
       val manager = createTaskSetManager(taskSet, maxTaskFailures)
-      activeTaskSets(taskSet.id) = manager
+      val stage = taskSet.stageId
+      val stageTaskSets =
+        taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
+      stageTaskSets(taskSet.stageAttemptId) = manager
+      val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
+        ts.taskSet != taskSet && !ts.isZombie
+      }
+      if (conflictingTaskSet) {
+        throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
+          s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
+      }
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
 
       if (!isLocal && !hasReceivedTask) {
@@ -192,19 +202,21 @@ private[spark] class TaskSchedulerImpl(
 
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
-      // There are two possible cases here:
-      // 1. The task set manager has been created and some tasks have been scheduled.
-      //    In this case, send a kill signal to the executors to kill the task and then abort
-      //    the stage.
-      // 2. The task set manager has been created but no tasks has been scheduled. In this case,
-      //    simply abort the stage.
-      tsm.runningTasksSet.foreach { tid =>
-        val execId = taskIdToExecutorId(tid)
-        backend.killTask(tid, execId, interruptThread)
+    taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
+      attempts.foreach { case (_, tsm) =>
+        // There are two possible cases here:
+        // 1. The task set manager has been created and some tasks have been scheduled.
+        //    In this case, send a kill signal to the executors to kill the task and then abort
+        //    the stage.
+        // 2. The task set manager has been created but no tasks has been scheduled. In this case,
+        //    simply abort the stage.
+        tsm.runningTasksSet.foreach { tid =>
+          val execId = taskIdToExecutorId(tid)
+          backend.killTask(tid, execId, interruptThread)
+        }
+        tsm.abort("Stage %s cancelled".format(stageId))
+        logInfo("Stage %d was cancelled".format(stageId))
       }
-      tsm.abort("Stage %s cancelled".format(stageId))
-      logInfo("Stage %d was cancelled".format(stageId))
     }
   }
 
@@ -214,7 +226,12 @@ private[spark] class TaskSchedulerImpl(
    * cleaned up.
    */
   def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    activeTaskSets -= manager.taskSet.id
+    taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
+      taskSetsForStage -= manager.taskSet.stageAttemptId
+      if (taskSetsForStage.isEmpty) {
+        taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
+      }
+    }
     manager.parent.removeSchedulable(manager)
     logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
       .format(manager.taskSet.id, manager.parent.name))
@@ -235,7 +252,7 @@ private[spark] class TaskSchedulerImpl(
         for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
           tasks(i) += task
           val tid = task.taskId
-          taskIdToTaskSetId(tid) = taskSet.taskSet.id
+          taskIdToTaskSetManager(tid) = taskSet
           taskIdToExecutorId(tid) = execId
           executorsByHost(host) += execId
           availableCpus(i) -= CPUS_PER_TASK
@@ -319,26 +336,24 @@ private[spark] class TaskSchedulerImpl(
             failedExecutor = Some(execId)
           }
         }
-        taskIdToTaskSetId.get(tid) match {
-          case Some(taskSetId) =>
+        taskIdToTaskSetManager.get(tid) match {
+          case Some(taskSet) =>
             if (TaskState.isFinished(state)) {
-              taskIdToTaskSetId.remove(tid)
+              taskIdToTaskSetManager.remove(tid)
              taskIdToExecutorId.remove(tid)
            }
-            activeTaskSets.get(taskSetId).foreach { taskSet =>
-              if (state == TaskState.FINISHED) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
-              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
-              }
+            if (state == TaskState.FINISHED) {
+              taskSet.removeRunningTask(tid)
+              taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+            } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+              taskSet.removeRunningTask(tid)
+              taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
            }
          case None =>
            logError(
              ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
-                "likely the result of receiving duplicate task finished status updates)")
-                .format(state, tid))
+              "likely the result of receiving duplicate task finished status updates)")
+              .format(state, tid))
        }
      } catch {
        case e: Exception => logError("Exception in statusUpdate", e)
@@ -363,9 +378,9 @@ private[spark] class TaskSchedulerImpl(
 
    val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
      taskMetrics.flatMap { case (id, metrics) =>
-        taskIdToTaskSetId.get(id)
-          .flatMap(activeTaskSets.get)
-          .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
+        taskIdToTaskSetManager.get(id).map { taskSetMgr =>
+          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+        }
      }
    }
    dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
@@ -397,9 +412,12 @@ private[spark] class TaskSchedulerImpl(
 
  def error(message: String) {
    synchronized {
-      if (activeTaskSets.nonEmpty) {
+      if (taskSetsByStageIdAndAttempt.nonEmpty) {
        // Have each task set throw a SparkException with the error
-        for ((taskSetId, manager) <- activeTaskSets) {
+        for {
+          attempts <- taskSetsByStageIdAndAttempt.values
+          manager <- attempts.values
+        } {
          try {
            manager.abort(message)
          } catch {
@@ -520,6 +538,17 @@ private[spark] class TaskSchedulerImpl(
 
  override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()
 
+  private[scheduler] def taskSetManagerForAttempt(
+      stageId: Int,
+      stageAttemptId: Int): Option[TaskSetManager] = {
+    for {
+      attempts <- taskSetsByStageIdAndAttempt.get(stageId)
+      manager <- attempts.get(stageAttemptId)
+    } yield {
+      manager
+    }
+  }
+
 }
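
The heart of the fix lives in this file: task sets are now indexed by stage id and attempt id, and registering a second live task set for a stage fails fast. Below is a self-contained sketch of that invariant (DemoTaskSet, DemoManager, and DemoScheduler are invented stand-ins; only the map shape and the zombie-aware conflict check mirror taskSetsByStageIdAndAttempt above):

import scala.collection.mutable.HashMap

// Sketch of the two-level index: stageId -> (stageAttemptId -> manager).
class DemoTaskSet(val stageId: Int, val stageAttemptId: Int) {
  val id: String = s"$stageId.$stageAttemptId"
}

class DemoManager(val taskSet: DemoTaskSet) {
  var isZombie: Boolean = false // set once a newer attempt supersedes this one
}

class DemoScheduler {
  private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, DemoManager]]

  def submitTasks(taskSet: DemoTaskSet): DemoManager = synchronized {
    val manager = new DemoManager(taskSet)
    val stageTaskSets =
      taskSetsByStageIdAndAttempt.getOrElseUpdate(taskSet.stageId, new HashMap)
    stageTaskSets(taskSet.stageAttemptId) = manager
    // Only zombie (superseded) attempts may coexist with the new task set.
    val conflicting = stageTaskSets.exists { case (_, m) =>
      m.taskSet != taskSet && !m.isZombie
    }
    if (conflicting) {
      throw new IllegalStateException(
        s"more than one active taskSet for stage ${taskSet.stageId}: " +
          stageTaskSets.values.map(_.taskSet.id).mkString(","))
    }
    manager
  }
}

object DemoSchedulerApp {
  def main(args: Array[String]): Unit = {
    val scheduler = new DemoScheduler
    val first = scheduler.submitTasks(new DemoTaskSet(stageId = 0, stageAttemptId = 0))
    first.isZombie = true // retire attempt 0 before resubmitting the stage
    scheduler.submitTasks(new DemoTaskSet(stageId = 0, stageAttemptId = 1)) // accepted
  }
}

Marking the old attempt a zombie before the new one arrives is what lets a retried stage coexist with still-running tasks from its predecessor without two attempts ever being actively scheduled at once.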

core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala

Lines changed: 2 additions & 2 deletions
@@ -26,10 +26,10 @@ import java.util.Properties
 private[spark] class TaskSet(
     val tasks: Array[Task[_]],
     val stageId: Int,
-    val attempt: Int,
+    val stageAttemptId: Int,
     val priority: Int,
     val properties: Properties) {
-  val id: String = stageId + "." + attempt
+  val id: String = stageId + "." + stageAttemptId
 
   override def toString: String = "TaskSet " + id
 }
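
The rename from attempt to stageAttemptId is purely for consistency with the new Task.stageAttemptId field; the derived string id keeps its old "<stageId>.<stageAttemptId>" format, as this tiny illustrative sketch shows:

// Illustrative only: the task set id format is unchanged by the rename.
class DemoTaskSet(val stageId: Int, val stageAttemptId: Int) {
  val id: String = s"$stageId.$stageAttemptId"
}

object TaskSetIdDemo {
  def main(args: Array[String]): Unit = {
    println(new DemoTaskSet(stageId = 4, stageAttemptId = 2).id) // prints "4.2"
  }
}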

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala

Lines changed: 2 additions & 3 deletions
@@ -191,15 +191,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
       for (task <- tasks.flatten) {
         val serializedTask = ser.serialize(task)
         if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
-          val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
-          scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
+          scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
             try {
               var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
                 "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
                 "spark.akka.frameSize or using broadcast variables for large values."
               msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
                 AkkaUtils.reservedSizeBytes)
-              taskSet.abort(msg)
+              taskSetMgr.abort(msg)
             } catch {
               case e: Exception => logError("Exception in error callback", e)
             }
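
With taskIdToTaskSetManager available, the oversized-task error path resolves a task id to its manager in a single lookup instead of first fetching a task-set id string and then consulting activeTaskSets. A minimal sketch of the new lookup shape (invented names, not the real backend):

import scala.collection.mutable.HashMap

// One map lookup replaces the old two-step taskId -> taskSetId -> manager chain.
class DemoTaskSetManager(val id: String) {
  def abort(message: String): Unit = println(s"aborting task set $id: $message")
}

object LookupDemo {
  def main(args: Array[String]): Unit = {
    val taskIdToTaskSetManager = new HashMap[Long, DemoTaskSetManager]
    taskIdToTaskSetManager(7L) = new DemoTaskSetManager("0.0")
    taskIdToTaskSetManager.get(7L).foreach { mgr =>
      mgr.abort("serialized task was too large")
    }
  }
}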
