
Commit 2d87a62

pgandhi authored and committed
[SPARK-26755] : Optimize Spark Scheduler to dequeue speculative tasks more efficiently
Split the main queue "speculatableTasks" into five separate queues based on locality preference, similar to how normal tasks are enqueued.
1 parent 8baf3ba commit 2d87a62
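
For context, a minimal sketch of the idea (simplified, hypothetical names; not the committed code, which is shown in the diff below): instead of re-scanning a single HashSet on every resource offer, speculatable task indices are appended to per-executor, per-host, per-rack, no-preference, and catch-all buffers, and an offer pops from the tail of whichever buffer matches its locality level.

  import scala.collection.mutable.{ArrayBuffer, HashMap}

  // Hypothetical, stripped-down model of the five speculative-task queues;
  // the real change lives in TaskSetManager and reuses its existing helpers.
  class SpeculativeQueuesSketch {
    val forExecutor = new HashMap[String, ArrayBuffer[Int]]
    val forHost = new HashMap[String, ArrayBuffer[Int]]
    val forRack = new HashMap[String, ArrayBuffer[Int]]
    val noPrefs = new ArrayBuffer[Int]
    val all = new ArrayBuffer[Int]

    // Enqueue a task index under every locality level it prefers.
    def enqueue(index: Int, executors: Seq[String], hosts: Seq[String], racks: Seq[String]): Unit = {
      executors.foreach(e => forExecutor.getOrElseUpdate(e, new ArrayBuffer) += index)
      hosts.foreach(h => forHost.getOrElseUpdate(h, new ArrayBuffer) += index)
      racks.foreach(r => forRack.getOrElseUpdate(r, new ArrayBuffer) += index)
      if (executors.isEmpty && hosts.isEmpty && racks.isEmpty) noPrefs += index
      all += index
    }

    // Pop from the tail of one buffer, skipping indices that no longer qualify
    // (e.g. already finished, or already attempted on the offered host).
    def dequeue(buf: ArrayBuffer[Int], stillWanted: Int => Boolean): Option[Int] = {
      while (buf.nonEmpty) {
        val index = buf.remove(buf.size - 1)
        if (stillWanted(index)) return Some(index)
      }
      None
    }
  }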

File tree

2 files changed: +95 −63 lines changed


core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 92 additions & 60 deletions

@@ -155,9 +155,21 @@ private[spark] class TaskSetManager(
   // Set containing all pending tasks (also used as a stack, as above).
   private val allPendingTasks = new ArrayBuffer[Int]
 
-  // Tasks that can be speculated. Since these will be a small fraction of total
-  // tasks, we'll just hold them in a HashSet.
-  private[scheduler] val speculatableTasks = new HashSet[Int]
+  // Set of pending tasks that can be speculated for each executor.
+  private[scheduler] var pendingSpeculatableTasksForExecutor =
+    new HashMap[String, ArrayBuffer[Int]]
+
+  // Set of pending tasks that can be speculated for each host.
+  private[scheduler] var pendingSpeculatableTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+  // Set of pending tasks that can be speculated with no locality preferences.
+  private[scheduler] val pendingSpeculatableTasksWithNoPrefs = new ArrayBuffer[Int]
+
+  // Set of pending tasks that can be speculated for each rack.
+  private[scheduler] var pendingSpeculatableTasksForRack = new HashMap[String, ArrayBuffer[Int]]
+
+  // Set of all pending tasks that can be speculated.
+  private[scheduler] val allPendingSpeculatableTasks = new ArrayBuffer[Int]
 
   // Task index, start and finish time for each task attempt (indexed by task ID)
   private[scheduler] val taskInfos = new HashMap[Long, TaskInfo]
@@ -245,6 +257,28 @@ private[spark] class TaskSetManager(
     allPendingTasks += index  // No point scanning this whole list to find the old task there
   }
 
+  private[spark] def addPendingSpeculativeTask(index: Int) {
+    for (loc <- tasks(index).preferredLocations) {
+      loc match {
+        case e: ExecutorCacheTaskLocation =>
+          pendingSpeculatableTasksForExecutor.getOrElseUpdate(
+            e.executorId, new ArrayBuffer) += index
+        case _ =>
+      }
+      pendingSpeculatableTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer) += index
+      for (rack <- sched.getRackForHost(loc.host)) {
+        pendingSpeculatableTasksForRack.getOrElseUpdate(rack, new ArrayBuffer) += index
+      }
+    }
+
+    if (tasks(index).preferredLocations == Nil) {
+      pendingSpeculatableTasksWithNoPrefs += index
+    }
+
+    // No point scanning this whole list to find the old task there
+    allPendingSpeculatableTasks += index
+  }
+
   /**
    * Return the pending tasks list for a given executor ID, or an empty list if
    * there is no map entry for that host
@@ -294,6 +328,30 @@ private[spark] class TaskSetManager(
     None
   }
 
+  /**
+   * Dequeue a pending speculative task from the given list and return its index. Runs similar
+   * to the method 'dequeueTaskFromList' with additional constraints. Return None if the
+   * list is empty.
+   */
+  private def dequeueSpeculativeTaskFromList(
+      execId: String,
+      host: String,
+      list: ArrayBuffer[Int]): Option[Int] = {
+    var indexOffset = list.size
+    while (indexOffset > 0) {
+      indexOffset -= 1
+      val index = list(indexOffset)
+      if (!isTaskBlacklistedOnExecOrNode(index, execId, host) && !hasAttemptOnHost(index, host)) {
+        // This should almost always be list.trimEnd(1) to remove tail
+        list.remove(indexOffset)
+        if (!successful(index)) {
+          return Some(index)
+        }
+      }
+    }
+    None
+  }
+
   /** Check whether a task once ran an attempt on a given host */
   private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
     taskAttempts(taskIndex).exists(_.host == host)
@@ -315,69 +373,44 @@ private[spark] class TaskSetManager(
   protected def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
     : Option[(Int, TaskLocality.Value)] =
   {
-    speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
-
-    def canRunOnHost(index: Int): Boolean = {
-      !hasAttemptOnHost(index, host) &&
-        !isTaskBlacklistedOnExecOrNode(index, execId, host)
-    }
-
-    if (!speculatableTasks.isEmpty) {
     // Check for process-local tasks; note that tasks can be process-local
     // on multiple nodes when we replicate cached blocks, as in Spark Streaming
-      for (index <- speculatableTasks if canRunOnHost(index)) {
-        val prefs = tasks(index).preferredLocations
-        val executors = prefs.flatMap(_ match {
-          case e: ExecutorCacheTaskLocation => Some(e.executorId)
-          case _ => None
-        });
-        if (executors.contains(execId)) {
-          speculatableTasks -= index
-          return Some((index, TaskLocality.PROCESS_LOCAL))
-        }
-      }
+    for (index <- dequeueSpeculativeTaskFromList(
+        execId, host, pendingSpeculatableTasksForExecutor.getOrElse(execId, ArrayBuffer()))) {
+      return Some((index, TaskLocality.PROCESS_LOCAL))
+    }
 
-      // Check for node-local tasks
-      if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
-        for (index <- speculatableTasks if canRunOnHost(index)) {
-          val locations = tasks(index).preferredLocations.map(_.host)
-          if (locations.contains(host)) {
-            speculatableTasks -= index
-            return Some((index, TaskLocality.NODE_LOCAL))
-          }
-        }
+    // Check for node-local tasks
+    if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+      for (index <- dequeueSpeculativeTaskFromList(
+          execId, host, pendingSpeculatableTasksForHost.getOrElse(host, ArrayBuffer()))) {
+        return Some((index, TaskLocality.NODE_LOCAL))
       }
+    }
 
-      // Check for no-preference tasks
-      if (TaskLocality.isAllowed(locality, TaskLocality.NO_PREF)) {
-        for (index <- speculatableTasks if canRunOnHost(index)) {
-          val locations = tasks(index).preferredLocations
-          if (locations.size == 0) {
-            speculatableTasks -= index
-            return Some((index, TaskLocality.PROCESS_LOCAL))
-          }
-        }
+    // Check for no-preference tasks
+    if (TaskLocality.isAllowed(locality, TaskLocality.NO_PREF)) {
+      for (index <- dequeueSpeculativeTaskFromList(
+          execId, host, pendingSpeculatableTasksWithNoPrefs)) {
+        return Some((index, TaskLocality.PROCESS_LOCAL))
       }
+    }
 
-      // Check for rack-local tasks
-      if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
-        for (rack <- sched.getRackForHost(host)) {
-          for (index <- speculatableTasks if canRunOnHost(index)) {
-            val racks = tasks(index).preferredLocations.map(_.host).flatMap(sched.getRackForHost)
-            if (racks.contains(rack)) {
-              speculatableTasks -= index
-              return Some((index, TaskLocality.RACK_LOCAL))
-            }
-          }
-        }
+    // Check for rack-local tasks
+    if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+      for {
+        rack <- sched.getRackForHost(host)
+        index <- dequeueSpeculativeTaskFromList(
+          execId, host, pendingSpeculatableTasksForRack.getOrElse(rack, ArrayBuffer()))
+      } {
+        return Some((index, TaskLocality.RACK_LOCAL))
       }
+    }
 
-      // Check for non-local tasks
-      if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
-        for (index <- speculatableTasks if canRunOnHost(index)) {
-          speculatableTasks -= index
-          return Some((index, TaskLocality.ANY))
-        }
+    // Check for non-local tasks
+    if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+      for (index <- dequeueSpeculativeTaskFromList(execId, host, allPendingSpeculatableTasks)) {
+        return Some((index, TaskLocality.ANY))
       }
     }
 
@@ -1029,12 +1062,11 @@ private[spark] class TaskSetManager(
     for (tid <- runningTasksSet) {
      val info = taskInfos(tid)
      val index = info.index
-      if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
-        !speculatableTasks.contains(index)) {
+      if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold) {
+        addPendingSpeculativeTask(index)
        logInfo(
          "Marking task %d in stage %s (on %s) as speculatable because it ran more than %.0f ms"
            .format(index, taskSet.id, info.host, threshold))
-        speculatableTasks += index
        sched.dagScheduler.speculativeTaskSubmitted(tasks(index))
        foundTasks = true
      }
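
Taken together, each resource offer now pops from at most one small per-locality buffer at each level instead of re-scanning every speculatable task. A stripped-down model of that cascade (plain Scala with hypothetical names, and string labels standing in for TaskLocality; the real method above additionally honours blacklisting, earlier attempts on the offered host, and the caller's maximum allowed locality level):

  import scala.collection.mutable.{ArrayBuffer, HashMap}

  object SpeculativeCascadeSketch {
    // Pop from the tail of a buffer; None if it is empty.
    private def pop(buf: ArrayBuffer[Int]): Option[Int] =
      if (buf.isEmpty) None else Some(buf.remove(buf.size - 1))

    // Walk the queues in locality order: executor, host, no-preference, rack, any.
    def next(
        execId: String,
        host: String,
        rack: Option[String],
        forExecutor: HashMap[String, ArrayBuffer[Int]],
        forHost: HashMap[String, ArrayBuffer[Int]],
        noPrefs: ArrayBuffer[Int],
        forRack: HashMap[String, ArrayBuffer[Int]],
        all: ArrayBuffer[Int]): Option[(Int, String)] = {
      pop(forExecutor.getOrElse(execId, ArrayBuffer())).map((_, "PROCESS_LOCAL"))
        .orElse(pop(forHost.getOrElse(host, ArrayBuffer())).map((_, "NODE_LOCAL")))
        .orElse(pop(noPrefs).map((_, "PROCESS_LOCAL")))
        .orElse(rack.flatMap(r => pop(forRack.getOrElse(r, ArrayBuffer()))).map((_, "RACK_LOCAL")))
        .orElse(pop(all).map((_, "ANY")))
    }
  }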

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

Lines changed: 3 additions & 3 deletions

@@ -724,7 +724,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
 
     // Mark the task as available for speculation, and then offer another resource,
     // which should be used to launch a speculative copy of the task.
-    manager.speculatableTasks += singleTask.partitionId
+    manager.addPendingSpeculativeTask(singleTask.partitionId)
     val task2 = manager.resourceOffer("execB", "host2", TaskLocality.ANY).get
 
     assert(manager.runningTasks === 2)
@@ -869,7 +869,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     assert(manager.resourceOffer("execA", "host1", NODE_LOCAL) == None)
     assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index == 1)
 
-    manager.speculatableTasks += 1
+    manager.addPendingSpeculativeTask(1)
     clock.advance(LOCALITY_WAIT_MS)
     // schedule the nonPref task
     assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index === 2)
@@ -1151,7 +1151,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be
     // killed, so the FakeTaskScheduler is only told about the successful completion
     // of the speculated task.
-    assert(sched.endedTasks(3) === Success)
+    assert(sched.endedTasks(4) === Success)
     // also because the scheduler is a mock, our manager isn't notified about the task killed event,
     // so we do that manually
     manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled("test"))
