@@ -75,10 +75,9 @@ private[spark] class TaskSchedulerImpl(
75
75
76
76
// TaskSetManagers are not thread safe, so any access to one should be synchronized
77
77
// on this class.
78
- val activeTaskSets = new HashMap [String , TaskSetManager ]
79
- val taskSetsByStage = new HashMap [Int , HashMap [Int , TaskSetManager ]]
78
+ val stageIdToActiveTaskSet = new HashMap [Int , TaskSetManager ]
80
79
81
- val taskIdToTaskSetId = new HashMap [Long , String ]
80
+ val taskIdToStageId = new HashMap [Long , Int ]
82
81
val taskIdToExecutorId = new HashMap [Long , String ]
83
82
84
83
@ volatile private var hasReceivedTask = false
@@ -163,17 +162,13 @@ private[spark] class TaskSchedulerImpl(
163
162
logInfo(" Adding task set " + taskSet.id + " with " + tasks.length + " tasks" )
164
163
this .synchronized {
165
164
val manager = createTaskSetManager(taskSet, maxTaskFailures)
166
- activeTaskSets(taskSet.id) = manager
167
- val stage = taskSet.stageId
168
- val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap [Int , TaskSetManager ])
169
- stageTaskSets(taskSet.attempt) = manager
170
- val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
171
- ts.taskSet != taskSet && ! ts.isZombie
172
- }
173
- if (conflictingTaskSet) {
174
- throw new IllegalStateException (s " more than one active taskSet for stage $stage: " +
175
- s " ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(" ," )}" )
165
+ stageIdToActiveTaskSet(taskSet.stageId) = manager
166
+ val stageId = taskSet.stageId
167
+ stageIdToActiveTaskSet.get(stageId).map { activeTaskSet =>
168
+ throw new IllegalStateException (
169
+ s " Active taskSet with id already exists for stage $stageId: ${activeTaskSet.taskSet.id}" )
176
170
}
171
+ stageIdToActiveTaskSet(stageId) = manager
177
172
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
178
173
179
174
if (! isLocal && ! hasReceivedTask) {
@@ -203,7 +198,7 @@ private[spark] class TaskSchedulerImpl(
203
198
204
199
override def cancelTasks (stageId : Int , interruptThread : Boolean ): Unit = synchronized {
205
200
logInfo(" Cancelling stage " + stageId)
206
- activeTaskSets.find(_._2. stageId == stageId).foreach { case (_, tsm) =>
201
+ stageIdToActiveTaskSet.get( stageId).map { tsm =>
207
202
// There are two possible cases here:
208
203
// 1. The task set manager has been created and some tasks have been scheduled.
209
204
// In this case, send a kill signal to the executors to kill the task and then abort
@@ -225,13 +220,7 @@ private[spark] class TaskSchedulerImpl(
225
220
* cleaned up.
226
221
*/
227
222
def taskSetFinished (manager : TaskSetManager ): Unit = synchronized {
228
- activeTaskSets -= manager.taskSet.id
229
- taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
230
- taskSetsForStage -= manager.taskSet.attempt
231
- if (taskSetsForStage.isEmpty) {
232
- taskSetsByStage -= manager.taskSet.stageId
233
- }
234
- }
223
+ stageIdToActiveTaskSet -= manager.stageId
235
224
manager.parent.removeSchedulable(manager)
236
225
logInfo(" Removed TaskSet %s, whose tasks have all completed, from pool %s"
237
226
.format(manager.taskSet.id, manager.parent.name))
@@ -252,7 +241,7 @@ private[spark] class TaskSchedulerImpl(
252
241
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
253
242
tasks(i) += task
254
243
val tid = task.taskId
255
- taskIdToTaskSetId (tid) = taskSet.taskSet.id
244
+ taskIdToStageId (tid) = taskSet.taskSet.stageId
256
245
taskIdToExecutorId(tid) = execId
257
246
executorsByHost(host) += execId
258
247
availableCpus(i) -= CPUS_PER_TASK
@@ -336,13 +325,13 @@ private[spark] class TaskSchedulerImpl(
336
325
failedExecutor = Some (execId)
337
326
}
338
327
}
339
- taskIdToTaskSetId .get(tid) match {
340
- case Some (taskSetId ) =>
328
+ taskIdToStageId .get(tid) match {
329
+ case Some (stageId ) =>
341
330
if (TaskState .isFinished(state)) {
342
- taskIdToTaskSetId .remove(tid)
331
+ taskIdToStageId .remove(tid)
343
332
taskIdToExecutorId.remove(tid)
344
333
}
345
- activeTaskSets .get(taskSetId ).foreach { taskSet =>
334
+ stageIdToActiveTaskSet .get(stageId ).foreach { taskSet =>
346
335
if (state == TaskState .FINISHED ) {
347
336
taskSet.removeRunningTask(tid)
348
337
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
@@ -380,8 +369,8 @@ private[spark] class TaskSchedulerImpl(
380
369
381
370
val metricsWithStageIds : Array [(Long , Int , Int , TaskMetrics )] = synchronized {
382
371
taskMetrics.flatMap { case (id, metrics) =>
383
- taskIdToTaskSetId .get(id)
384
- .flatMap(activeTaskSets .get)
372
+ taskIdToStageId .get(id)
373
+ .flatMap(stageIdToActiveTaskSet .get)
385
374
.map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
386
375
}
387
376
}
@@ -414,9 +403,9 @@ private[spark] class TaskSchedulerImpl(
414
403
415
404
def error (message : String ) {
416
405
synchronized {
417
- if (activeTaskSets .nonEmpty) {
406
+ if (stageIdToActiveTaskSet .nonEmpty) {
418
407
// Have each task set throw a SparkException with the error
419
- for ((taskSetId , manager) <- activeTaskSets ) {
408
+ for ((_ , manager) <- stageIdToActiveTaskSet ) {
420
409
try {
421
410
manager.abort(message)
422
411
} catch {
0 commit comments