// Set containing all pending tasks (also used as a stack, as above).
private val allPendingTasks = new ArrayBuffer[Int]

// Pending speculatable tasks keyed by executor id, for executors that hold a
// cached copy of the task's data (process-local candidates).
private[scheduler] var pendingSpeculatableTasksForExecutor =
  new HashMap[String, ArrayBuffer[Int]]

// Pending speculatable tasks keyed by preferred host (node-local candidates).
private[scheduler] var pendingSpeculatableTasksForHost = new HashMap[String, ArrayBuffer[Int]]

// Pending speculatable tasks that declared no locality preference at all.
private[scheduler] val pendingSpeculatableTasksWithNoPrefs = new ArrayBuffer[Int]

// Pending speculatable tasks keyed by rack (rack-local candidates).
private[scheduler] var pendingSpeculatableTasksForRack = new HashMap[String, ArrayBuffer[Int]]

// Every pending speculatable task, regardless of locality (ANY-level candidates).
private[scheduler] val allPendingSpeculatableTasks = new ArrayBuffer[Int]

// Task index, start and finish time for each task attempt (indexed by task ID)
private[scheduler] val taskInfos = new HashMap[Long, TaskInfo]
@@ -245,6 +257,28 @@ private[spark] class TaskSetManager(
245
257
allPendingTasks += index // No point scanning this whole list to find the old task there
246
258
}
247
259
260
/**
 * Register a task as speculatable by enqueueing its index on every locality-specific
 * pending-speculatable list that matches its preferred locations, plus the catch-all list.
 *
 * @param index index of the task within this TaskSet's `tasks` array
 */
private[spark] def addPendingSpeculativeTask(index: Int): Unit = {
  // Hoist the lookup: the original evaluated tasks(index).preferredLocations twice.
  val prefLocs = tasks(index).preferredLocations
  for (loc <- prefLocs) {
    loc match {
      case e: ExecutorCacheTaskLocation =>
        // Data is cached on a specific executor, so a copy there would be process-local.
        pendingSpeculatableTasksForExecutor.getOrElseUpdate(
          e.executorId, new ArrayBuffer) += index
      case _ =>
    }
    pendingSpeculatableTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer) += index
    // getRackForHost returns an Option; the loop body runs only when a rack is known.
    for (rack <- sched.getRackForHost(loc.host)) {
      pendingSpeculatableTasksForRack.getOrElseUpdate(rack, new ArrayBuffer) += index
    }
  }

  // isEmpty instead of `== Nil`: preferredLocations is a Seq, not necessarily a List.
  if (prefLocs.isEmpty) {
    pendingSpeculatableTasksWithNoPrefs += index
  }

  // Catch-all list used for TaskLocality.ANY scheduling.
  allPendingSpeculatableTasks += index
}
281
+
248
282
/**
249
283
* Return the pending tasks list for a given executor ID, or an empty list if
250
284
* there is no map entry for that host
@@ -294,6 +328,30 @@ private[spark] class TaskSetManager(
294
328
None
295
329
}
296
330
331
/**
 * Dequeue a pending speculative task from the given list and return its index. Runs similar
 * to the method 'dequeueTaskFromList' with additional constraints: the task must not be
 * blacklisted on this executor/node and must not already have an attempt on this host.
 * Returns None if no launchable task is found.
 */
private def dequeueSpeculativeTaskFromList(
    execId: String,
    host: String,
    list: ArrayBuffer[Int]): Option[Int] = {
  // Scan from the tail so the common case is a cheap removal of the last element.
  var pos = list.size - 1
  while (pos >= 0) {
    val taskIndex = list(pos)
    val launchable = !isTaskBlacklistedOnExecOrNode(taskIndex, execId, host) &&
      !hasAttemptOnHost(taskIndex, host)
    if (launchable) {
      list.remove(pos)
      if (!successful(taskIndex)) {
        return Some(taskIndex)
      }
      // Already-finished tasks are pruned lazily here; keep scanning for a live one.
    }
    pos -= 1
  }
  None
}
354
+
297
355
/** Check whether a task once ran an attempt on a given host. */
private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean =
  taskAttempts(taskIndex).exists(attempt => attempt.host == host)
@@ -315,69 +373,44 @@ private[spark] class TaskSetManager(
315
373
protected def dequeueSpeculativeTask (execId : String , host : String , locality : TaskLocality .Value )
316
374
: Option [(Int , TaskLocality .Value )] =
317
375
{
318
- speculatableTasks.retain(index => ! successful(index)) // Remove finished tasks from set
319
-
320
- def canRunOnHost (index : Int ): Boolean = {
321
- ! hasAttemptOnHost(index, host) &&
322
- ! isTaskBlacklistedOnExecOrNode(index, execId, host)
323
- }
324
-
325
- if (! speculatableTasks.isEmpty) {
326
376
// Check for process-local tasks; note that tasks can be process-local
327
377
// on multiple nodes when we replicate cached blocks, as in Spark Streaming
328
- for (index <- speculatableTasks if canRunOnHost(index)) {
329
- val prefs = tasks(index).preferredLocations
330
- val executors = prefs.flatMap(_ match {
331
- case e : ExecutorCacheTaskLocation => Some (e.executorId)
332
- case _ => None
333
- });
334
- if (executors.contains(execId)) {
335
- speculatableTasks -= index
336
- return Some ((index, TaskLocality .PROCESS_LOCAL ))
337
- }
338
- }
378
+ for (index <- dequeueSpeculativeTaskFromList(
379
+ execId, host, pendingSpeculatableTasksForExecutor.getOrElse(execId, ArrayBuffer ()))) {
380
+ return Some ((index, TaskLocality .PROCESS_LOCAL ))
381
+ }
339
382
340
- // Check for node-local tasks
341
- if (TaskLocality .isAllowed(locality, TaskLocality .NODE_LOCAL )) {
342
- for (index <- speculatableTasks if canRunOnHost(index)) {
343
- val locations = tasks(index).preferredLocations.map(_.host)
344
- if (locations.contains(host)) {
345
- speculatableTasks -= index
346
- return Some ((index, TaskLocality .NODE_LOCAL ))
347
- }
348
- }
383
+ // Check for node-local tasks
384
+ if (TaskLocality .isAllowed(locality, TaskLocality .NODE_LOCAL )) {
385
+ for (index <- dequeueSpeculativeTaskFromList(
386
+ execId, host, pendingSpeculatableTasksForHost.getOrElse(host, ArrayBuffer ()))) {
387
+ return Some ((index, TaskLocality .NODE_LOCAL ))
349
388
}
389
+ }
350
390
351
- // Check for no-preference tasks
352
- if (TaskLocality .isAllowed(locality, TaskLocality .NO_PREF )) {
353
- for (index <- speculatableTasks if canRunOnHost(index)) {
354
- val locations = tasks(index).preferredLocations
355
- if (locations.size == 0 ) {
356
- speculatableTasks -= index
357
- return Some ((index, TaskLocality .PROCESS_LOCAL ))
358
- }
359
- }
391
+ // Check for no-preference tasks
392
+ if (TaskLocality .isAllowed(locality, TaskLocality .NO_PREF )) {
393
+ for (index <- dequeueSpeculativeTaskFromList(
394
+ execId, host, pendingSpeculatableTasksWithNoPrefs)) {
395
+ return Some ((index, TaskLocality .PROCESS_LOCAL ))
360
396
}
397
+ }
361
398
362
- // Check for rack-local tasks
363
- if (TaskLocality .isAllowed(locality, TaskLocality .RACK_LOCAL )) {
364
- for (rack <- sched.getRackForHost(host)) {
365
- for (index <- speculatableTasks if canRunOnHost(index)) {
366
- val racks = tasks(index).preferredLocations.map(_.host).flatMap(sched.getRackForHost)
367
- if (racks.contains(rack)) {
368
- speculatableTasks -= index
369
- return Some ((index, TaskLocality .RACK_LOCAL ))
370
- }
371
- }
372
- }
399
+ // Check for rack-local tasks
400
+ if (TaskLocality .isAllowed(locality, TaskLocality .RACK_LOCAL )) {
401
+ for {
402
+ rack <- sched.getRackForHost(host)
403
+ index <- dequeueSpeculativeTaskFromList(
404
+ execId, host, pendingSpeculatableTasksForRack.getOrElse(rack, ArrayBuffer ()))
405
+ } {
406
+ return Some ((index, TaskLocality .RACK_LOCAL ))
373
407
}
408
+ }
374
409
375
- // Check for non-local tasks
376
- if (TaskLocality .isAllowed(locality, TaskLocality .ANY )) {
377
- for (index <- speculatableTasks if canRunOnHost(index)) {
378
- speculatableTasks -= index
379
- return Some ((index, TaskLocality .ANY ))
380
- }
410
+ // Check for non-local tasks
411
+ if (TaskLocality .isAllowed(locality, TaskLocality .ANY )) {
412
+ for (index <- dequeueSpeculativeTaskFromList(execId, host, allPendingSpeculatableTasks)) {
413
+ return Some ((index, TaskLocality .ANY ))
381
414
}
382
415
}
383
416
@@ -1029,12 +1062,11 @@ private[spark] class TaskSetManager(
1029
1062
for (tid <- runningTasksSet) {
1030
1063
val info = taskInfos(tid)
1031
1064
val index = info.index
1032
- if (! successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
1033
- ! speculatableTasks.contains (index)) {
1065
+ if (! successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold) {
1066
+ addPendingSpeculativeTask (index)
1034
1067
logInfo(
1035
1068
" Marking task %d in stage %s (on %s) as speculatable because it ran more than %.0f ms"
1036
1069
.format(index, taskSet.id, info.host, threshold))
1037
- speculatableTasks += index
1038
1070
sched.dagScheduler.speculativeTaskSubmitted(tasks(index))
1039
1071
foundTasks = true
1040
1072
}
0 commit comments