
Commit 1dc776b

Merge pull request alteryx#93 from kayousterhout/ui_new_state

Show "GETTING_RESULTS" state in UI.

This commit adds a set of calls using the SparkListener interface that indicate when a task is remotely fetching results, so that we can display this (potentially time-consuming) phase of execution to users through the UI.

2 parents: c4b187d + b45352e
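For orientation, here is a minimal sketch (not part of this commit) of how an application could consume the new callback through the listener API added below. The class name GettingResultLogger is hypothetical; sc is assumed to be an existing SparkContext.

    import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskGettingResult}

    // Hypothetical listener: logs each task that enters the GETTING_RESULTS phase.
    // onTaskGettingResult is only invoked for tasks whose results are fetched
    // remotely through the block manager.
    class GettingResultLogger extends SparkListener {
      override def onTaskGettingResult(event: SparkListenerTaskGettingResult) {
        println("Task %d is remotely fetching its result".format(event.taskInfo.index))
      }
    }

    // Usage, given an existing SparkContext sc:
    // sc.addSparkListener(new GettingResultLogger)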

File tree: 10 files changed (+130 -8 lines)


core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 8 additions & 0 deletions

@@ -68,6 +68,11 @@ class DAGScheduler(
     eventQueue.put(BeginEvent(task, taskInfo))
   }
 
+  // Called to report that a task has completed and results are being fetched remotely.
+  def taskGettingResult(task: Task[_], taskInfo: TaskInfo) {
+    eventQueue.put(GettingResultEvent(task, taskInfo))
+  }
+
   // Called by TaskScheduler to report task completions or failures.
   def taskEnded(
       task: Task[_],

@@ -415,6 +420,9 @@ class DAGScheduler(
       case begin: BeginEvent =>
         listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
 
+      case gettingResult: GettingResultEvent =>
+        listenerBus.post(SparkListenerTaskGettingResult(gettingResult.task, gettingResult.taskInfo))
+
       case completion: CompletionEvent =>
        listenerBus.post(SparkListenerTaskEnd(
          completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
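The new handler follows the DAGScheduler's existing pattern: the scheduler-facing method only enqueues an event, and a separate event-processing loop dequeues it and posts to the listener bus. Below is a simplified, self-contained sketch of that pattern, with illustrative names rather than the actual DAGScheduler internals.

    import java.util.concurrent.LinkedBlockingQueue

    // Stand-ins for the scheduler's event types (illustrative only).
    sealed trait SketchEvent
    case class SketchGettingResult(taskIndex: Int) extends SketchEvent

    object EventLoopSketch {
      val eventQueue = new LinkedBlockingQueue[SketchEvent]

      // Producer side: called from the task scheduler's thread; does nothing
      // beyond the queue insert.
      def taskGettingResult(taskIndex: Int) {
        eventQueue.put(SketchGettingResult(taskIndex))
      }

      // Consumer side: the event loop takes events off the queue and would post
      // the corresponding SparkListener event to the listener bus.
      def processNextEvent() {
        eventQueue.take() match {
          case SketchGettingResult(i) =>
            println("post SparkListenerTaskGettingResult for task " + i)
        }
      }
    }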

core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala

Lines changed: 3 additions & 0 deletions

@@ -53,6 +53,9 @@ private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
 private[scheduler]
 case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
 
+private[scheduler]
+case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+
 private[scheduler] case class CompletionEvent(
     task: Task[_],
     reason: TaskEndReason,
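Since the DAGScheduler events are a small hierarchy of case classes, code inside org.apache.spark.scheduler can pattern-match on the new event alongside the existing ones. A hypothetical in-package helper (the classes are private[scheduler], so this only compiles within that package):

    // Hypothetical helper: a readable description for a few event types.
    def describe(event: DAGSchedulerEvent): String = event match {
      case BeginEvent(_, info)         => "task %d started".format(info.index)
      case GettingResultEvent(_, info) => "task %d remotely fetching its result".format(info.index)
      case _                           => "other scheduler event"
    }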

core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala

Lines changed: 9 additions & 0 deletions

@@ -31,6 +31,9 @@ case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
 
 case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
 
+case class SparkListenerTaskGettingResult(
+  task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
+
 case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
   taskMetrics: TaskMetrics) extends SparkListenerEvents
 

@@ -56,6 +59,12 @@ trait SparkListener {
    */
   def onTaskStart(taskStart: SparkListenerTaskStart) { }
 
+  /**
+   * Called when a task begins remotely fetching its result (will not be called for tasks that do
+   * not need to fetch the result remotely).
+   */
+  def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
+
   /**
    * Called when a task ends
    */
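Because the trait supplies empty default implementations, a listener overrides only the hooks it needs. As a hedged example (the class name is illustrative), here is a listener that counts how many finished tasks had to fetch their result remotely:

    import java.util.concurrent.atomic.AtomicInteger
    import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskGettingResult}

    // Illustrative listener: tallies ended tasks and remote result fetches.
    class RemoteFetchCounter extends SparkListener {
      val remoteFetches = new AtomicInteger(0)
      val endedTasks = new AtomicInteger(0)

      override def onTaskGettingResult(event: SparkListenerTaskGettingResult) {
        remoteFetches.incrementAndGet()
      }

      override def onTaskEnd(event: SparkListenerTaskEnd) {
        endedTasks.incrementAndGet()
      }
    }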

core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala

Lines changed: 2 additions & 0 deletions

@@ -49,6 +49,8 @@ private[spark] class SparkListenerBus() extends Logging {
         sparkListeners.foreach(_.onJobEnd(jobEnd))
       case taskStart: SparkListenerTaskStart =>
         sparkListeners.foreach(_.onTaskStart(taskStart))
+      case taskGettingResult: SparkListenerTaskGettingResult =>
+        sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
       case taskEnd: SparkListenerTaskEnd =>
         sparkListeners.foreach(_.onTaskEnd(taskEnd))
       case _ =>

core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala

Lines changed: 20 additions & 0 deletions

@@ -31,9 +31,25 @@ class TaskInfo(
     val host: String,
     val taskLocality: TaskLocality.TaskLocality) {
 
+  /**
+   * The time when the task started remotely getting the result. Will not be set if the
+   * task result was sent immediately when the task finished (as opposed to sending an
+   * IndirectTaskResult and later fetching the result from the block manager).
+   */
+  var gettingResultTime: Long = 0
+
+  /**
+   * The time when the task has completed successfully (including the time to remotely fetch
+   * results, if necessary).
+   */
   var finishTime: Long = 0
+
   var failed = false
 
+  def markGettingResult(time: Long = System.currentTimeMillis) {
+    gettingResultTime = time
+  }
+
   def markSuccessful(time: Long = System.currentTimeMillis) {
     finishTime = time
   }

@@ -43,6 +59,8 @@ class TaskInfo(
     failed = true
   }
 
+  def gettingResult: Boolean = gettingResultTime != 0
+
   def finished: Boolean = finishTime != 0
 
   def successful: Boolean = finished && !failed

@@ -52,6 +70,8 @@ class TaskInfo(
   def status: String = {
     if (running)
       "RUNNING"
+    else if (gettingResult)
+      "GET RESULT"
     else if (failed)
       "FAILED"
     else if (successful)
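Together with finishTime, the new field lets UI code derive how long a task spent in the GETTING_RESULTS phase. A small sketch using only the members shown above (the helper name is hypothetical):

    import org.apache.spark.scheduler.TaskInfo

    // Milliseconds a finished task spent remotely fetching its result;
    // 0 for tasks whose result was sent back directly.
    def resultFetchMillis(info: TaskInfo): Long = {
      if (info.gettingResult && info.finished) info.finishTime - info.gettingResultTime
      else 0L
    }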

core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala

Lines changed: 4 additions & 0 deletions

@@ -306,6 +306,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
     }
   }
 
+  def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
+    taskSetManager.handleTaskGettingResult(tid)
+  }
+
   def handleSuccessfulTask(
     taskSetManager: ClusterTaskSetManager,
     tid: Long,

core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala

Lines changed: 6 additions & 0 deletions

@@ -418,6 +418,12 @@ private[spark] class ClusterTaskSetManager(
     sched.dagScheduler.taskStarted(task, info)
   }
 
+  def handleTaskGettingResult(tid: Long) = {
+    val info = taskInfos(tid)
+    info.markGettingResult()
+    sched.dagScheduler.taskGettingResult(tasks(info.index), info)
+  }
+
   /**
    * Marks the task as successful and notifies the DAGScheduler that a task has ended.
    */

core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
             case directResult: DirectTaskResult[_] => directResult
             case IndirectTaskResult(blockId) =>
               logDebug("Fetching indirect task result for TID %s".format(tid))
+              scheduler.handleTaskGettingResult(taskSetManager, tid)
               val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
               if (!serializedTaskResult.isDefined) {
                 /* We won't be able to get the task result if the machine that ran the task failed
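This one-line hook sits exactly on the slow path: a DirectTaskResult arrives inline with the task's status update, while an IndirectTaskResult only names a block that the driver must then fetch from a remote block manager, which is the phase now surfaced in the UI. A simplified sketch of the distinction (the threshold logic is illustrative; per the tests below, results larger than the Akka frame size take the indirect path):

    // Illustrative only: which result path a task takes, given the serialized
    // result size and the configured Akka frame size (both in bytes).
    def resultPath(resultBytes: Long, akkaFrameSizeBytes: Long): String =
      if (resultBytes > akkaFrameSizeBytes) {
        "indirect: stored in block manager, fetched remotely (GETTING_RESULTS shown)"
      } else {
        "direct: sent inline with the status update"
      }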

core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala

Lines changed: 7 additions & 1 deletion

@@ -115,7 +115,13 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
       taskList += ((taskStart.taskInfo, None, None))
       stageIdToTaskInfos(sid) = taskList
   }
-
+
+  override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult)
+      = synchronized {
+    // Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in
+    // stageToTaskInfos already has the updated status.
+  }
+
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
     val sid = taskEnd.task.stageId
     val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())

core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala

Lines changed: 70 additions & 7 deletions

@@ -17,22 +17,25 @@
 
 package org.apache.spark.scheduler
 
-import org.scalatest.{BeforeAndAfter, FunSuite}
-import org.apache.spark.{LocalSparkContext, SparkContext}
-import scala.collection.mutable
+import scala.collection.mutable.{Buffer, HashSet}
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
 import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.{LocalSparkContext, SparkContext}
 import org.apache.spark.SparkContext._
 
 class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
-  with BeforeAndAfter {
+  with BeforeAndAfterAll {
   /** Length of time to wait while draining listener events. */
   val WAIT_TIMEOUT_MILLIS = 10000
 
-  before {
-    sc = new SparkContext("local", "DAGSchedulerSuite")
+  override def afterAll {
+    System.clearProperty("spark.akka.frameSize")
   }
 
   test("basic creation of StageInfo") {
+    sc = new SparkContext("local", "DAGSchedulerSuite")
     val listener = new SaveStageInfo
     sc.addSparkListener(listener)
     val rdd1 = sc.parallelize(1 to 100, 4)

@@ -53,6 +56,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
   }
 
   test("StageInfo with fewer tasks than partitions") {
+    sc = new SparkContext("local", "DAGSchedulerSuite")
     val listener = new SaveStageInfo
     sc.addSparkListener(listener)
     val rdd1 = sc.parallelize(1 to 100, 4)

@@ -68,6 +72,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
   }
 
   test("local metrics") {
+    sc = new SparkContext("local", "DAGSchedulerSuite")
     val listener = new SaveStageInfo
     sc.addSparkListener(listener)
     sc.addSparkListener(new StatsReportListener)

@@ -129,15 +134,73 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
     }
   }
 
+  test("onTaskGettingResult() called when result fetched remotely") {
+    // Need to use local cluster mode here, because results are not ever returned through the
+    // block manager when using the LocalScheduler.
+    sc = new SparkContext("local-cluster[1,1,512]", "test")
+
+    val listener = new SaveTaskEvents
+    sc.addSparkListener(listener)
+
+    // Make a task whose result is larger than the akka frame size
+    System.setProperty("spark.akka.frameSize", "1")
+    val akkaFrameSize =
+      sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+    val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x,y) => x)
+    assert(result === 1.to(akkaFrameSize).toArray)
+
+    assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    val TASK_INDEX = 0
+    assert(listener.startedTasks.contains(TASK_INDEX))
+    assert(listener.startedGettingResultTasks.contains(TASK_INDEX))
+    assert(listener.endedTasks.contains(TASK_INDEX))
+  }
+
+  test("onTaskGettingResult() not called when result sent directly") {
+    // Need to use local cluster mode here, because results are not ever returned through the
+    // block manager when using the LocalScheduler.
+    sc = new SparkContext("local-cluster[1,1,512]", "test")
+
+    val listener = new SaveTaskEvents
+    sc.addSparkListener(listener)
+
+    // Make a task whose result is small enough to be sent back directly
+    val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
+    assert(result === 2)
+
+    assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    val TASK_INDEX = 0
+    assert(listener.startedTasks.contains(TASK_INDEX))
+    assert(listener.startedGettingResultTasks.isEmpty == true)
+    assert(listener.endedTasks.contains(TASK_INDEX))
+  }
+
   def checkNonZeroAvg(m: Traversable[Long], msg: String) {
     assert(m.sum / m.size.toDouble > 0.0, msg)
   }
 
   class SaveStageInfo extends SparkListener {
-    val stageInfos = mutable.Buffer[StageInfo]()
+    val stageInfos = Buffer[StageInfo]()
     override def onStageCompleted(stage: StageCompleted) {
       stageInfos += stage.stage
     }
   }
 
+  class SaveTaskEvents extends SparkListener {
+    val startedTasks = new HashSet[Int]()
+    val startedGettingResultTasks = new HashSet[Int]()
+    val endedTasks = new HashSet[Int]()
+
+    override def onTaskStart(taskStart: SparkListenerTaskStart) {
+      startedTasks += taskStart.taskInfo.index
+    }
+
+    override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+      endedTasks += taskEnd.taskInfo.index
+    }
+
+    override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
+      startedGettingResultTasks += taskGettingResult.taskInfo.index
+    }
+  }
 }
