Skip to content

Commit f025154

Browse files
committed
Merge pull request #2 from kayousterhout/imran_SPARK-8103
Index active task sets by stage id rather than by task set id
2 parents 19685bb + baf46e1 commit f025154

File tree

3 files changed

+23
-34
lines changed

3 files changed

+23
-34
lines changed

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,9 @@ private[spark] class TaskSchedulerImpl(
7575

7676
// TaskSetManagers are not thread safe, so any access to one should be synchronized
7777
// on this class.
78-
val activeTaskSets = new HashMap[String, TaskSetManager]
79-
val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]]
78+
val stageIdToActiveTaskSet = new HashMap[Int, TaskSetManager]
8079

81-
val taskIdToTaskSetId = new HashMap[Long, String]
80+
val taskIdToStageId = new HashMap[Long, Int]
8281
val taskIdToExecutorId = new HashMap[Long, String]
8382

8483
@volatile private var hasReceivedTask = false
@@ -163,17 +162,13 @@ private[spark] class TaskSchedulerImpl(
163162
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
164163
this.synchronized {
165164
val manager = createTaskSetManager(taskSet, maxTaskFailures)
166-
activeTaskSets(taskSet.id) = manager
167-
val stage = taskSet.stageId
168-
val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
169-
stageTaskSets(taskSet.attempt) = manager
170-
val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
171-
ts.taskSet != taskSet && !ts.isZombie
172-
}
173-
if (conflictingTaskSet) {
174-
throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
175-
s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
165+
stageIdToActiveTaskSet(taskSet.stageId) = manager
166+
val stageId = taskSet.stageId
167+
stageIdToActiveTaskSet.get(stageId).map { activeTaskSet =>
168+
throw new IllegalStateException(
169+
s"Active taskSet with id already exists for stage $stageId: ${activeTaskSet.taskSet.id}")
176170
}
171+
stageIdToActiveTaskSet(stageId) = manager
177172
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
178173

179174
if (!isLocal && !hasReceivedTask) {
@@ -203,7 +198,7 @@ private[spark] class TaskSchedulerImpl(
203198

204199
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
205200
logInfo("Cancelling stage " + stageId)
206-
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
201+
stageIdToActiveTaskSet.get(stageId).map {tsm =>
207202
// There are two possible cases here:
208203
// 1. The task set manager has been created and some tasks have been scheduled.
209204
// In this case, send a kill signal to the executors to kill the task and then abort
@@ -225,13 +220,7 @@ private[spark] class TaskSchedulerImpl(
225220
* cleaned up.
226221
*/
227222
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
228-
activeTaskSets -= manager.taskSet.id
229-
taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
230-
taskSetsForStage -= manager.taskSet.attempt
231-
if (taskSetsForStage.isEmpty) {
232-
taskSetsByStage -= manager.taskSet.stageId
233-
}
234-
}
223+
stageIdToActiveTaskSet -= manager.stageId
235224
manager.parent.removeSchedulable(manager)
236225
logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
237226
.format(manager.taskSet.id, manager.parent.name))
@@ -252,7 +241,7 @@ private[spark] class TaskSchedulerImpl(
252241
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
253242
tasks(i) += task
254243
val tid = task.taskId
255-
taskIdToTaskSetId(tid) = taskSet.taskSet.id
244+
taskIdToStageId(tid) = taskSet.taskSet.stageId
256245
taskIdToExecutorId(tid) = execId
257246
executorsByHost(host) += execId
258247
availableCpus(i) -= CPUS_PER_TASK
@@ -336,13 +325,13 @@ private[spark] class TaskSchedulerImpl(
336325
failedExecutor = Some(execId)
337326
}
338327
}
339-
taskIdToTaskSetId.get(tid) match {
340-
case Some(taskSetId) =>
328+
taskIdToStageId.get(tid) match {
329+
case Some(stageId) =>
341330
if (TaskState.isFinished(state)) {
342-
taskIdToTaskSetId.remove(tid)
331+
taskIdToStageId.remove(tid)
343332
taskIdToExecutorId.remove(tid)
344333
}
345-
activeTaskSets.get(taskSetId).foreach { taskSet =>
334+
stageIdToActiveTaskSet.get(stageId).foreach { taskSet =>
346335
if (state == TaskState.FINISHED) {
347336
taskSet.removeRunningTask(tid)
348337
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
@@ -380,8 +369,8 @@ private[spark] class TaskSchedulerImpl(
380369

381370
val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
382371
taskMetrics.flatMap { case (id, metrics) =>
383-
taskIdToTaskSetId.get(id)
384-
.flatMap(activeTaskSets.get)
372+
taskIdToStageId.get(id)
373+
.flatMap(stageIdToActiveTaskSet.get)
385374
.map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
386375
}
387376
}
@@ -414,9 +403,9 @@ private[spark] class TaskSchedulerImpl(
414403

415404
def error(message: String) {
416405
synchronized {
417-
if (activeTaskSets.nonEmpty) {
406+
if (stageIdToActiveTaskSet.nonEmpty) {
418407
// Have each task set throw a SparkException with the error
419-
for ((taskSetId, manager) <- activeTaskSets) {
408+
for ((_, manager) <- stageIdToActiveTaskSet) {
420409
try {
421410
manager.abort(message)
422411
} catch {

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
191191
for (task <- tasks.flatten) {
192192
val serializedTask = ser.serialize(task)
193193
if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
194-
val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
195-
scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
194+
val taskSetId = scheduler.taskIdToStageId(task.taskId)
195+
scheduler.stageIdToActiveTaskSet.get(taskSetId).foreach { taskSet =>
196196
try {
197197
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
198198
"spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,11 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
144144
intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) }
145145

146146
// OK to submit multiple if previous attempts are all zombie
147-
taskScheduler.activeTaskSets(attempt1.id).isZombie = true
147+
taskScheduler.stageIdToActiveTaskSet(attempt1.stageId).isZombie = true
148148
taskScheduler.submitTasks(attempt2)
149149
val attempt3 = new TaskSet(Array(new FakeTask(0)), 0, 2, 0, null)
150150
intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) }
151-
taskScheduler.activeTaskSets(attempt2.id).isZombie = true
151+
taskScheduler.stageIdToActiveTaskSet(attempt2.stageId).isZombie = true
152152
taskScheduler.submitTasks(attempt3)
153153
}
154154

0 commit comments

Comments
 (0)