
Commit 37d8f37

Added a submitJob interface that returns a Future of the result.
1 parent 1cb42e6 commit 37d8f37

File tree: 8 files changed, +185 −87 lines changed

Lines changed: 50 additions & 0 deletions (new file: class org.apache.spark.FutureJob)

@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.util.concurrent.{ExecutionException, TimeUnit, Future}
+
+import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter}
+
+class FutureJob[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: () => T)
+  extends Future[T] {
+
+  override def isDone: Boolean = jobWaiter.jobFinished
+
+  override def cancel(mayInterruptIfRunning: Boolean): Boolean = {
+    jobWaiter.kill()
+    true
+  }
+
+  override def isCancelled: Boolean = {
+    throw new UnsupportedOperationException
+  }
+
+  override def get(): T = {
+    jobWaiter.awaitResult() match {
+      case JobSucceeded =>
+        resultFunc()
+      case JobFailed(e: Exception, _) =>
+        throw new ExecutionException(e)
+    }
+  }
+
+  override def get(timeout: Long, unit: TimeUnit): T = {
+    throw new UnsupportedOperationException
+  }
+}
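For orientation, here is a minimal usage sketch (not part of this commit) showing how the returned Future behaves: JobSucceeded maps to the resultFunc value, while JobFailed is rethrown from get() as an ExecutionException. It assumes an existing SparkContext named sc and uses collectAsync(), the RDD method added later in this commit; the deliberately failing map function is purely illustrative.

// Hypothetical sketch (assumes a SparkContext `sc`); not part of this commit.
import java.util.concurrent.ExecutionException

val failing = sc.parallelize(1 to 100, 4).map { i =>
  if (i == 42) throw new RuntimeException("boom")   // force a task failure
  i
}

try {
  failing.collectAsync().get()   // blocks; JobFailed comes back as ExecutionException
} catch {
  case e: ExecutionException => println("Job failed: " + e.getCause)
}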

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 19 additions & 0 deletions

@@ -20,6 +20,7 @@ package org.apache.spark
 import java.io._
 import java.net.URI
 import java.util.Properties
+import java.util.concurrent.Future
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.Map
@@ -812,6 +813,24 @@
     result
   }
 
+  def submitJob[T, U, R](
+      rdd: RDD[T],
+      processPartition: Iterator[T] => U,
+      partitionResultHandler: (Int, U) => Unit,
+      resultFunc: () => R): Future[R] =
+  {
+    val callSite = Utils.formatSparkCallSite
+    val waiter = dagScheduler.submitJob(
+      rdd,
+      (context: TaskContext, iter: Iterator[T]) => processPartition(iter),
+      0 until rdd.partitions.size,
+      callSite,
+      allowLocal = false,
+      partitionResultHandler,
+      null)
+    new FutureJob(waiter, resultFunc)
+  }
+
   /**
    * Kill a running job.
    */
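As an illustration of the three callbacks, here is a hedged sketch (not from the commit) of an asynchronous count built directly on the new submitJob. The variable names and the RDD[String] input are hypothetical, and sc is an assumed existing SparkContext; processPartition runs on the executors, while partitionResultHandler and resultFunc run on the driver as partition results arrive.

// Hypothetical sketch; `sc` and `lines: RDD[String]` are assumed to exist.
import java.util.concurrent.Future

var total = 0L
val countFuture: Future[Long] = sc.submitJob[String, Long, Long](
  lines,
  iter => iter.size.toLong,                  // processPartition: count one partition on an executor
  (index, partCount) => total += partCount,  // partitionResultHandler: accumulate on the driver
  () => total)                               // resultFunc: assemble the final value

println("Line count: " + countFuture.get())  // block only when the answer is needed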

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 10 additions & 0 deletions

@@ -18,6 +18,7 @@
 package org.apache.spark.rdd
 
 import java.util.Random
+import java.util.concurrent.Future
 
 import scala.collection.Map
 import scala.collection.JavaConversions.mapAsScalaMap
@@ -561,6 +562,15 @@ abstract class RDD[T: ClassManifest](
     Array.concat(results: _*)
   }
 
+  /**
+   * Return a future for retrieving the results of a collect in an asynchronous fashion.
+   */
+  def collectAsync(): Future[Seq[T]] = {
+    val results = new ArrayBuffer[T]
+    sc.submitJob[T, Array[T], Seq[T]](
+      this, _.toArray, (index, data) => results ++= data, () => results)
+  }
+
   /**
    * Return an array that contains all of the elements in this RDD.
    */
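A small usage sketch of collectAsync (not from the commit), assuming an existing SparkContext sc; the HDFS path is a placeholder. The call returns immediately, and the driver can do other work before blocking on get().

// Hypothetical usage; `sc` is an assumed SparkContext and the path is a placeholder.
val logs = sc.textFile("hdfs://...")
val errorsFuture = logs.filter(_.contains("ERROR")).collectAsync()

// ... other driver-side work can proceed while the job runs ...

val errors: Seq[String] = errorsFuture.get()   // block only when the result is needed
println("Found " + errors.size + " error lines")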

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 68 additions & 66 deletions

@@ -105,13 +105,15 @@ class DAGScheduler(
 
   private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
 
-  val nextJobId = new AtomicInteger(0)
+  private[scheduler] val nextJobId = new AtomicInteger(0)
 
-  val nextStageId = new AtomicInteger(0)
+  def numTotalJobs: Int = nextJobId.get()
 
-  val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+  private val nextStageId = new AtomicInteger(0)
 
-  val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
+  private val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+
+  private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
 
   private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
 
@@ -263,54 +265,50 @@ class DAGScheduler(
   }
 
   /**
-   * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
-   * JobWaiter whose getResult() method will return the result of the job when it is complete.
-   *
-   * The job is assumed to have at least one partition; zero partition jobs should be handled
-   * without a JobSubmitted event.
+   * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
+   * can be used to block until the the job finishes executing or can be used to kill the job.
+   * If the given RDD does not contain any partitions, the function returns None.
    */
-  private[scheduler] def prepareJob[T, U: ClassManifest](
-      finalRdd: RDD[T],
+  def submitJob[T, U](
+      rdd: RDD[T],
       func: (TaskContext, Iterator[T]) => U,
       partitions: Seq[Int],
       callSite: String,
       allowLocal: Boolean,
       resultHandler: (Int, U) => Unit,
-      properties: Properties = null)
-    : (JobSubmitted, JobWaiter[U]) =
+      properties: Properties = null): JobWaiter[U] =
   {
+    val jobId = nextJobId.getAndIncrement()
+    if (partitions.size == 0) {
+      return new JobWaiter[U](this, jobId, 0, resultHandler)
+    }
+
+    // Check to make sure we are not launching a task on a partition that does not exist.
+    val maxPartitions = rdd.partitions.length
+    partitions.find(p => p >= maxPartitions).foreach { p =>
+      throw new IllegalArgumentException(
+        "Attempting to access a non-existent partition: " + p + ". " +
+          "Total number of partitions: " + maxPartitions)
+    }
+
     assert(partitions.size > 0)
-    val waiter = new JobWaiter(partitions.size, resultHandler)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
-    val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
-      properties)
-    (toSubmit, waiter)
+    val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
+    eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite,
+      waiter, properties))
+    waiter
   }
 
   def runJob[T, U: ClassManifest](
-      finalRdd: RDD[T],
+      rdd: RDD[T],
       func: (TaskContext, Iterator[T]) => U,
       partitions: Seq[Int],
       callSite: String,
       allowLocal: Boolean,
       resultHandler: (Int, U) => Unit,
       properties: Properties = null)
   {
-    if (partitions.size == 0) {
-      return
-    }
-
-    // Check to make sure we are not launching a task on a partition that does not exist.
-    val maxPartitions = finalRdd.partitions.length
-    partitions.find(p => p >= maxPartitions).foreach { p =>
-      throw new IllegalArgumentException(
-        "Attempting to access a non-existent partition: " + p + ". " +
-          "Total number of partitions: " + maxPartitions)
-    }
-
-    val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
-      finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
-    eventQueue.put(toSubmit)
+    val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
     waiter.awaitResult() match {
       case JobSucceeded => {}
       case JobFailed(exception: Exception, _) =>
@@ -331,45 +329,50 @@ class DAGScheduler(
     val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val partitions = (0 until rdd.partitions.size).toArray
-    eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
+    val jobId = nextJobId.getAndIncrement()
+    eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite,
+      listener, properties))
     listener.awaitResult()    // Will throw an exception if the job fails
   }
 
+  /**
+   * Kill a job that is running or waiting in the queue.
+   */
   def killJob(jobId: Int): Unit = this.synchronized {
     activeJobs.find(job => job.jobId == jobId).foreach(job => killJob(job))
-  }
 
-  private def killJob(job: ActiveJob): Unit = this.synchronized {
-    logInfo("Killing Job and cleaning up stages %d".format(job.jobId))
-    activeJobs.remove(job)
-    idToActiveJob.remove(job.jobId)
-    val stage = job.finalStage
-    resultStageToJob.remove(stage)
-    killStage(job, stage)
-    val e = new SparkException("Job killed")
-    job.listener.jobFailed(e)
-    listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, None)))
-  }
-
-  private def killStage(job: ActiveJob, stage: Stage): Unit = this.synchronized {
-    // TODO: Can we reuse taskSetFailed?
-    logInfo("Killing Stage %s".format(stage.id))
-    stageIdToStage.remove(stage.id)
-    if (stage.isShuffleMap) {
-      shuffleToMapStage.remove(stage.id)
-    }
-    waiting.remove(stage)
-    pendingTasks.remove(stage)
-    taskSched.killTasks(stage.id)
-
-    if (running.contains(stage)) {
-      running.remove(stage)
+  def killJob(job: ActiveJob): Unit = this.synchronized {
+    logInfo("Killing Job and cleaning up stages %d".format(job.jobId))
+    activeJobs.remove(job)
+    idToActiveJob.remove(job.jobId)
+    val stage = job.finalStage
+    resultStageToJob.remove(stage)
+    killStage(job, stage)
     val e = new SparkException("Job killed")
-      listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, Some(stage))))
+    job.listener.jobFailed(e)
+    listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, None)))
   }
 
-    stage.parents.foreach(parentStage => killStage(job, parentStage))
-    //stageToInfos -= stage
+  def killStage(job: ActiveJob, stage: Stage): Unit = this.synchronized {
+    // TODO: Can we reuse taskSetFailed?
+    logInfo("Killing Stage %s".format(stage.id))
+    stageIdToStage.remove(stage.id)
+    if (stage.isShuffleMap) {
+      shuffleToMapStage.remove(stage.id)
+    }
+    waiting.remove(stage)
+    pendingTasks.remove(stage)
+    taskSched.killTasks(stage.id)
+
+    if (running.contains(stage)) {
+      running.remove(stage)
+      val e = new SparkException("Job killed")
+      listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, Some(stage))))
+    }
+
+    stage.parents.foreach(parentStage => killStage(job, parentStage))
+    //stageToInfos -= stage
+  }
   }
 
   /**
@@ -378,9 +381,8 @@ class DAGScheduler(
    */
   private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
     event match {
-      case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
-        val jobId = nextJobId.getAndIncrement()
-        val finalStage = newStage(finalRDD, None, jobId, Some(callSite))
+      case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
+        val finalStage = newStage(rdd, None, jobId, Some(callSite))
         val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
         clearCacheLocs()
         logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +

core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala

Lines changed: 12 additions & 8 deletions

@@ -32,9 +32,10 @@ import org.apache.spark.executor.TaskMetrics
  * submitted) but there is a single "logic" thread that reads these events and takes decisions.
  * This greatly simplifies synchronization.
  */
-private[spark] sealed trait DAGSchedulerEvent
+private[scheduler] sealed trait DAGSchedulerEvent
 
-private[spark] case class JobSubmitted(
+private[scheduler] case class JobSubmitted(
+    jobId: Int,
     finalRDD: RDD[_],
     func: (TaskContext, Iterator[_]) => _,
     partitions: Array[Int],
@@ -44,9 +45,10 @@ private[spark] case class JobSubmitted(
     properties: Properties = null)
   extends DAGSchedulerEvent
 
-private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+private[scheduler]
+case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
 
-private[spark] case class CompletionEvent(
+private[scheduler] case class CompletionEvent(
     task: Task[_],
     reason: TaskEndReason,
     result: Any,
@@ -55,10 +57,12 @@ private[spark] case class CompletionEvent(
     taskMetrics: TaskMetrics)
   extends DAGSchedulerEvent
 
-private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
+private[scheduler]
+case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
 
-private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
+private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
 
-private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+private[scheduler]
+case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
 
-private[spark] case object StopDAGScheduler extends DAGSchedulerEvent
+private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent

core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar
   })
 
   metricRegistry.register(MetricRegistry.name("job", "allJobs", "number"), new Gauge[Int] {
-    override def getValue: Int = dagScheduler.nextJobId.get()
+    override def getValue: Int = dagScheduler.numTotalJobs
   })
 
   metricRegistry.register(MetricRegistry.name("job", "activeJobs", "number"), new Gauge[Int] {
