@@ -105,13 +105,15 @@ class DAGScheduler(

  private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]

-  val nextJobId = new AtomicInteger(0)
+  private[scheduler] val nextJobId = new AtomicInteger(0)

-  val nextStageId = new AtomicInteger(0)
+  def numTotalJobs: Int = nextJobId.get()

-  val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+  private val nextStageId = new AtomicInteger(0)

-  val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
+  private val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+
+  private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]

  private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]

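
The hunk above tightens visibility around two AtomicInteger counters and exposes a read-only accessor (numTotalJobs) instead of the mutable field. As a minimal stand-alone sketch of that pattern (class and method names here are hypothetical, not Spark code):

```scala
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical sketch of the nextJobId/numTotalJobs pattern: the mutable
// counter stays private, getAndIncrement() hands out unique IDs without
// locking, and callers only get a read-only view of the count.
class IdCounter {
  private val nextId = new AtomicInteger(0)

  def newId(): Int = nextId.getAndIncrement() // unique even under concurrent callers

  def numIssued: Int = nextId.get() // analogous to numTotalJobs above
}
```
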
@@ -263,54 +265,50 @@ class DAGScheduler(
  }

  /**
-   * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
-   * JobWaiter whose getResult() method will return the result of the job when it is complete.
-   *
-   * The job is assumed to have at least one partition; zero partition jobs should be handled
-   * without a JobSubmitted event.
+   * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter can be
+   * used to block until the job finishes executing or to kill the job. If the given RDD has
+   * no partitions, a JobWaiter is returned immediately without submitting the job.
   */
-  private[scheduler] def prepareJob[T, U: ClassManifest](
-      finalRdd: RDD[T],
+  def submitJob[T, U](
+      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      callSite: String,
      allowLocal: Boolean,
      resultHandler: (Int, U) => Unit,
-      properties: Properties = null)
-    : (JobSubmitted, JobWaiter[U]) =
+      properties: Properties = null): JobWaiter[U] =
  {
+    val jobId = nextJobId.getAndIncrement()
+    if (partitions.size == 0) {
+      return new JobWaiter[U](this, jobId, 0, resultHandler)
+    }
+
+    // Check to make sure we are not launching a task on a partition that does not exist.
+    val maxPartitions = rdd.partitions.length
+    partitions.find(p => p >= maxPartitions).foreach { p =>
+      throw new IllegalArgumentException(
+        "Attempting to access a non-existent partition: " + p + ". " +
+        "Total number of partitions: " + maxPartitions)
+    }
+
    assert(partitions.size > 0)
-    val waiter = new JobWaiter(partitions.size, resultHandler)
    val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
-    val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
-      properties)
-    (toSubmit, waiter)
+    val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
+    eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite,
+      waiter, properties))
+    waiter
  }

  def runJob[T, U: ClassManifest](
-      finalRdd: RDD[T],
+      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      callSite: String,
      allowLocal: Boolean,
      resultHandler: (Int, U) => Unit,
      properties: Properties = null)
  {
-    if (partitions.size == 0) {
-      return
-    }
-
-    // Check to make sure we are not launching a task on a partition that does not exist.
-    val maxPartitions = finalRdd.partitions.length
-    partitions.find(p => p >= maxPartitions).foreach { p =>
-      throw new IllegalArgumentException(
-        "Attempting to access a non-existent partition: " + p + ". " +
-        "Total number of partitions: " + maxPartitions)
-    }
-
-    val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
-      finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
-    eventQueue.put(toSubmit)
+    val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
    waiter.awaitResult() match {
      case JobSucceeded => {}
      case JobFailed(exception: Exception, _) =>
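
The change above splits job submission in two: submitJob allocates the job ID, enqueues a JobSubmitted event, and returns a JobWaiter without blocking, while runJob is reduced to submitJob plus awaitResult(). A simplified, self-contained sketch of that submit/await split (Waiter, Scheduler, and Outcome are stand-ins, not Spark's actual JobWaiter or DAGScheduler API):

```scala
import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue}

sealed trait Outcome
case object Succeeded extends Outcome
final case class Failed(e: Exception) extends Outcome

// Stand-in for JobWaiter: the caller blocks on awaitResult() until an
// event loop (not shown) calls finish().
class Waiter {
  private val latch = new CountDownLatch(1)
  @volatile private var outcome: Outcome = Succeeded

  def finish(o: Outcome): Unit = { outcome = o; latch.countDown() }

  def awaitResult(): Outcome = { latch.await(); outcome }
}

// Stand-in for the scheduler: submitJob is asynchronous, runJob is
// submitJob plus a blocking wait, mirroring the diff above.
class Scheduler {
  private val eventQueue = new LinkedBlockingQueue[(Int, Waiter)]

  def submitJob(jobId: Int): Waiter = {
    val waiter = new Waiter
    eventQueue.put((jobId, waiter)) // hand the job to the event loop
    waiter                          // return the handle immediately
  }

  def runJob(jobId: Int): Unit = submitJob(jobId).awaitResult() match {
    case Succeeded => ()
    case Failed(e) => throw e
  }
}
```

The point of the split, visible in the diff as well, is that the job ID and the waiter exist before the event loop ever sees the job, which is what makes killing a job by ID possible while it is still queued.
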
@@ -331,45 +329,50 @@ class DAGScheduler(
    val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
    val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
    val partitions = (0 until rdd.partitions.size).toArray
-    eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
+    val jobId = nextJobId.getAndIncrement()
+    eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite,
+      listener, properties))
    listener.awaitResult()  // Will throw an exception if the job fails
  }

+  /**
+   * Kill a job that is running or waiting in the queue.
+   */
  def killJob(jobId: Int): Unit = this.synchronized {
    activeJobs.find(job => job.jobId == jobId).foreach(job => killJob(job))
-  }

-  private def killJob(job: ActiveJob): Unit = this.synchronized {
-    logInfo("Killing Job and cleaning up stages %d".format(job.jobId))
-    activeJobs.remove(job)
-    idToActiveJob.remove(job.jobId)
-    val stage = job.finalStage
-    resultStageToJob.remove(stage)
-    killStage(job, stage)
-    val e = new SparkException("Job killed")
-    job.listener.jobFailed(e)
-    listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, None)))
-  }
-
-  private def killStage(job: ActiveJob, stage: Stage): Unit = this.synchronized {
-    // TODO: Can we reuse taskSetFailed?
-    logInfo("Killing Stage %s".format(stage.id))
-    stageIdToStage.remove(stage.id)
-    if (stage.isShuffleMap) {
-      shuffleToMapStage.remove(stage.id)
-    }
-    waiting.remove(stage)
-    pendingTasks.remove(stage)
-    taskSched.killTasks(stage.id)
-
-    if (running.contains(stage)) {
-      running.remove(stage)
+    def killJob(job: ActiveJob): Unit = this.synchronized {
+      logInfo("Killing job %d and cleaning up its stages".format(job.jobId))
+      activeJobs.remove(job)
+      idToActiveJob.remove(job.jobId)
+      val stage = job.finalStage
+      resultStageToJob.remove(stage)
+      killStage(job, stage)
      val e = new SparkException("Job killed")
-      listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, Some(stage))))
+      job.listener.jobFailed(e)
+      listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, None)))
    }

-    stage.parents.foreach(parentStage => killStage(job, parentStage))
-    // stageToInfos -= stage
+    def killStage(job: ActiveJob, stage: Stage): Unit = this.synchronized {
+      // TODO: Can we reuse taskSetFailed?
+      logInfo("Killing Stage %s".format(stage.id))
+      stageIdToStage.remove(stage.id)
+      if (stage.isShuffleMap) {
+        shuffleToMapStage.remove(stage.id)
+      }
+      waiting.remove(stage)
+      pendingTasks.remove(stage)
+      taskSched.killTasks(stage.id)
+
+      if (running.contains(stage)) {
+        running.remove(stage)
+        val e = new SparkException("Job killed")
+        listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, Some(stage))))
+      }
+
+      stage.parents.foreach(parentStage => killStage(job, parentStage))
+      // stageToInfos -= stage
+    }
  }

  /**
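
killStage above removes a stage's bookkeeping, asks the task scheduler to kill its tasks, and then recurses into stage.parents, so killing a job tears down its whole stage lineage. A toy illustration of that traversal shape (StageNode is a hypothetical stand-in for Stage):

```scala
// Hypothetical stand-in for Stage, just enough to show the recursion in
// killStage: clean up the current stage, then walk every parent stage.
final case class StageNode(id: Int, parents: List[StageNode])

def killStageSketch(stage: StageNode): Unit = {
  println(s"killing stage ${stage.id}") // stands in for the map removals and killTasks
  stage.parents.foreach(killStageSketch)
}
```

One consequence of recursing without a visited set: a parent shared by two stages (a diamond in the DAG) is visited once per path. The version in the diff tolerates this because its removals are idempotent, but it does repeat work.
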
@@ -378,9 +381,8 @@ class DAGScheduler(
   */
  private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
    event match {
-      case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
-        val jobId = nextJobId.getAndIncrement()
-        val finalStage = newStage(finalRDD, None, jobId, Some(callSite))
+      case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
+        val finalStage = newStage(rdd, None, jobId, Some(callSite))
        val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
        clearCacheLocs()
        logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
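
With this last hunk, JobSubmitted carries the job ID allocated at submission time, so processEvent no longer increments nextJobId itself. A reduced sketch of that event shape (the real case class is not shown in this diff; the field list is inferred from the pattern match above, and JobSubmittedLite is hypothetical):

```scala
// Reduced, hypothetical version of the event: the ID is assigned by the
// submitter and travels inside the event, rather than being assigned when
// the event is dequeued.
final case class JobSubmittedLite(jobId: Int, callSite: String)

def processEventLite(event: JobSubmittedLite): Unit = event match {
  case JobSubmittedLite(jobId, callSite) =>
    println(s"Got job $jobId ($callSite)") // the ID already exists; no counter here
}
```

Allocating the ID in submitJob means a caller holds a valid jobId (usable with killJob) even before the event loop has processed the submission.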