
Commit 435d817

[SPARK-979] Randomize order of offers.
This commit randomizes the order of resource offers to avoid scheduling all tasks on the same small set of machines.
1 parent 5a3ad10 commit 435d817
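The heart of the change is a single call to scala.util.Random.shuffle on the sequence of worker offers before any tasks are assigned. Below is a minimal, self-contained sketch of that idea, not the actual scheduler code; the case class simply mirrors the WorkerOffer fields that appear in the diff.

```scala
import scala.util.Random

// Illustration only: mirrors the WorkerOffer fields used in the diff below.
case class WorkerOffer(executorId: String, host: String, cores: Int)

object ShuffleSketch extends App {
  val offers = Seq(
    WorkerOffer("executor0", "host0", 1),
    WorkerOffer("executor1", "host1", 1))

  // Random.shuffle returns a new, randomly ordered copy of the sequence,
  // so a job whose tasks fit on the first few offers no longer lands on
  // the same machines every time resourceOffers is called.
  val shuffledOffers = Random.shuffle(offers)
  println(shuffledOffers.map(_.executorId).mkString(", "))
}
```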

File tree

4 files changed: +75 −41 lines changed

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 9 additions & 6 deletions
@@ -25,6 +25,7 @@ import scala.concurrent.duration._
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
+import scala.util.Random

 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
@@ -207,9 +208,11 @@ private[spark] class TaskSchedulerImpl(
       }
     }

-    // Build a list of tasks to assign to each worker
-    val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
-    val availableCpus = offers.map(o => o.cores).toArray
+    // Randomly shuffle offers to avoid always placing tasks on the same set of workers.
+    val shuffledOffers = Random.shuffle(offers)
+    // Build a list of tasks to assign to each worker.
+    val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
+    val availableCpus = shuffledOffers.map(o => o.cores).toArray
     val sortedTaskSets = rootPool.getSortedTaskSetQueue()
     for (taskSet <- sortedTaskSets) {
       logDebug("parentName: %s, name: %s, runningTasks: %s".format(
@@ -222,9 +225,9 @@ private[spark] class TaskSchedulerImpl(
     for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) {
       do {
         launchedTask = false
-        for (i <- 0 until offers.size) {
-          val execId = offers(i).executorId
-          val host = offers(i).host
+        for (i <- 0 until shuffledOffers.size) {
+          val execId = shuffledOffers(i).executorId
+          val host = shuffledOffers(i).host
           for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) {
             tasks(i) += task
             val tid = task.taskId

core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala

Lines changed: 16 additions & 0 deletions
@@ -24,3 +24,19 @@ class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int

   override def preferredLocations: Seq[TaskLocation] = prefLocs
 }
+
+object FakeTask {
+  /**
+   * Utility method to create a TaskSet, potentially setting a particular sequence of preferred
+   * locations for each task (given as varargs) if this sequence is not empty.
+   */
+  def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
+    if (prefLocs.size != 0 && prefLocs.size != numTasks) {
+      throw new IllegalArgumentException("Wrong number of task locations")
+    }
+    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
+      new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
+    }
+    new TaskSet(tasks, 0, 0, 0, null)
+  }
+}
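For reference, the helper is called with one varargs Seq[TaskLocation] per task, or with no locations at all; both forms appear in the updated suites below:

```scala
// No preferred locations: a single task with prefLocs = Nil.
val noPrefs = FakeTask.createTaskSet(1)

// One Seq of preferred locations per task (here, 2 tasks).
val withPrefs = FakeTask.createTaskSet(2,
  Seq(TaskLocation("host1", "exec1")),
  Seq(TaskLocation("host2")))
```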

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 43 additions & 13 deletions
@@ -25,6 +25,13 @@ import org.scalatest.FunSuite

 import org.apache.spark._

+class FakeSchedulerBackend extends SchedulerBackend {
+  def start() {}
+  def stop() {}
+  def reviveOffers() {}
+  def defaultParallelism() = 1
+}
+
 class FakeTaskSetManager(
     initPriority: Int,
     initStageId: Int,
@@ -107,7 +114,8 @@

 class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Logging {

-  def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl, taskSet: TaskSet): FakeTaskSetManager = {
+  def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl,
+      taskSet: TaskSet): FakeTaskSetManager = {
     new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet)
   }

@@ -135,10 +143,7 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin
   test("FIFO Scheduler Test") {
     sc = new SparkContext("local", "TaskSchedulerImplSuite")
     val taskScheduler = new TaskSchedulerImpl(sc)
-    var tasks = ArrayBuffer[Task[_]]()
-    val task = new FakeTask(0)
-    tasks += task
-    val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
+    val taskSet = FakeTask.createTaskSet(1)

     val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0)
     val schedulableBuilder = new FIFOSchedulableBuilder(rootPool)
@@ -162,10 +167,7 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin
   test("Fair Scheduler Test") {
     sc = new SparkContext("local", "TaskSchedulerImplSuite")
     val taskScheduler = new TaskSchedulerImpl(sc)
-    var tasks = ArrayBuffer[Task[_]]()
-    val task = new FakeTask(0)
-    tasks += task
-    val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
+    val taskSet = FakeTask.createTaskSet(1)

     val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
     System.setProperty("spark.scheduler.allocation.file", xmlPath)
@@ -219,10 +221,7 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin
   test("Nested Pool Test") {
     sc = new SparkContext("local", "TaskSchedulerImplSuite")
     val taskScheduler = new TaskSchedulerImpl(sc)
-    var tasks = ArrayBuffer[Task[_]]()
-    val task = new FakeTask(0)
-    tasks += task
-    val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
+    val taskSet = FakeTask.createTaskSet(1)

     val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
     val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1)
@@ -265,4 +264,35 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin
     checkTaskSetId(rootPool, 6)
     checkTaskSetId(rootPool, 2)
   }
+
+  test("Scheduler does not always schedule tasks on the same workers") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    var dagScheduler = new DAGScheduler(taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorGained(execId: String, host: String) {}
+    }
+
+    val numFreeCores = 1
+    val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores),
+      new WorkerOffer("executor1", "host1", numFreeCores))
+    // Repeatedly try to schedule a 1-task job, and make sure that it doesn't always
+    // get scheduled on the same executor. While there is a chance this test will fail
+    // because the task randomly gets placed on the first executor all 1000 times, the
+    // probability of that happening is 2^-1000 (so sufficiently small to be considered
+    // negligible).
+    val numTrials = 1000
+    val selectedExecutorIds = 1.to(numTrials).map { _ =>
+      val taskSet = FakeTask.createTaskSet(1)
+      taskScheduler.submitTasks(taskSet)
+      val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+      assert(1 === taskDescriptions.length)
+      taskDescriptions(0).executorId
+    }
+    var count = selectedExecutorIds.count(_ == workerOffers(0).executorId)
+    assert(count > 0)
+    assert(count < numTrials)
+  }
 }
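A note on the flakiness math in the new test: with two single-core offers and one task, the shuffle makes each trial an independent 50/50 pick between the executors, so the pair of assertions can only fail if all 1000 picks are identical, which happens with probability 2 × (1/2)^1000 = 2^-999 (the in-code comment's 2^-1000 counts only the all-on-the-first-executor case). A standalone sketch of the same counting argument, using a plain coin flip in place of the scheduler:

```scala
import scala.util.Random

// Stand-in for the scheduler: each trial picks executor 0 or 1 uniformly.
object CoinFlipSketch extends App {
  val numTrials = 1000
  val picks = Seq.fill(numTrials)(Random.nextInt(2))
  val count = picks.count(_ == 0)
  // Mirrors the test's assertions: fails only if every pick was identical,
  // i.e. with probability 2^-999.
  assert(count > 0, "all picks went to executor 1")
  assert(count < numTrials, "all picks went to executor 0")
}
```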

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

Lines changed: 7 additions & 22 deletions
@@ -88,7 +88,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
   test("TaskSet with no preferences") {
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
-    val taskSet = createTaskSet(1)
+    val taskSet = FakeTask.createTaskSet(1)
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)

     // Offer a host with no CPUs
@@ -114,7 +114,7 @@
   test("multiple offers with no preferences") {
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
-    val taskSet = createTaskSet(3)
+    val taskSet = FakeTask.createTaskSet(3)
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)

     // First three offers should all find tasks
@@ -145,7 +145,7 @@
   test("basic delay scheduling") {
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
-    val taskSet = createTaskSet(4,
+    val taskSet = FakeTask.createTaskSet(4,
       Seq(TaskLocation("host1", "exec1")),
       Seq(TaskLocation("host2", "exec2")),
       Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")),
@@ -190,7 +190,7 @@
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc,
       ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3"))
-    val taskSet = createTaskSet(5,
+    val taskSet = FakeTask.createTaskSet(5,
       Seq(TaskLocation("host1")),
       Seq(TaskLocation("host2")),
       Seq(TaskLocation("host2")),
@@ -229,7 +229,7 @@
   test("delay scheduling with failed hosts") {
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
-    val taskSet = createTaskSet(3,
+    val taskSet = FakeTask.createTaskSet(3,
       Seq(TaskLocation("host1")),
       Seq(TaskLocation("host2")),
       Seq(TaskLocation("host3"))
@@ -261,7 +261,7 @@
   test("task result lost") {
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
-    val taskSet = createTaskSet(1)
+    val taskSet = FakeTask.createTaskSet(1)
     val clock = new FakeClock
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)

@@ -278,7 +278,7 @@
   test("repeated failures lead to task set abortion") {
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
-    val taskSet = createTaskSet(1)
+    val taskSet = FakeTask.createTaskSet(1)
     val clock = new FakeClock
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)

@@ -298,21 +298,6 @@
     }
   }

-
-  /**
-   * Utility method to create a TaskSet, potentially setting a particular sequence of preferred
-   * locations for each task (given as varargs) if this sequence is not empty.
-   */
-  def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
-    if (prefLocs.size != 0 && prefLocs.size != numTasks) {
-      throw new IllegalArgumentException("Wrong number of task locations")
-    }
-    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
-      new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
-    }
-    new TaskSet(tasks, 0, 0, 0, null)
-  }
-
   def createTaskResult(id: Int): DirectTaskResult[Int] = {
     val valueSer = SparkEnv.get.serializer.newInstance()
     new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics)
