Commit 5d8b156 (1 parent: 9f18bad)

Changes based on Kay's review.

10 files changed, +98 −69 lines

core/src/main/scala/org/apache/spark/Accumulators.scala
Lines changed: 24 additions & 14 deletions

@@ -36,28 +36,30 @@ import org.apache.spark.serializer.JavaSerializer
  *
  * @param initialValue initial value of accumulator
  * @param param helper object defining how to add elements of type `R` and `T`
+ * @param _name human-readable name for use in Spark's web UI
+ * @param display whether to show accumulator values Spark's web UI
  * @tparam R the full accumulated data (result type)
  * @tparam T partial data that can be added in
  */
 class Accumulable[R, T] (
     @transient initialValue: R,
-    param: AccumulableParam[R, T])
+    param: AccumulableParam[R, T],
+    _name: Option[String],
+    val display: Boolean)
   extends Serializable {
 
-  val id = Accumulators.newId
+  def this(@transient initialValue: R, param: AccumulableParam[R, T]) =
+    this(initialValue, param, None, true)
+
+  val id: Long = Accumulators.newId
+  val name = _name.getOrElse(s"accumulator_$id")
+
   @transient private var value_ = initialValue // Current value on master
   val zero = param.zero(initialValue)  // Zero value to be passed to workers
   private var deserialized = false
 
   Accumulators.register(this, true)
 
-  /** A name for this accumulator / accumulable for display in Spark's UI.
-    * Note that names must be unique within a SparkContext. */
-  def name: String = s"accumulator_$id"
-
-  /** Whether to display this accumulator in the web UI. */
-  def display: Boolean = true
-
   /**
    * Add more data to this accumulator / accumulable
    * @param term the data to add

@@ -97,6 +99,16 @@ class Accumulable[R, T] (
     }
   }
 
+  /**
+   * Function to customize printing values of this accumulator.
+   */
+  def prettyValue(_value: R) = s"$value"
+
+  /**
+   * Function to customize printing partially accumulated (local) values of this accumulator.
+   */
+  def prettyPartialValue(_value: R) = prettyValue(_value)
+
   /**
    * Get the current value of this accumulator from within a task.
    *

@@ -226,11 +238,9 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
  * @param param helper object defining how to add elements of type `T`
  * @tparam T result type
  */
-class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], _name: String,
-    _display: Boolean) extends Accumulable[T,T](initialValue, param) {
-  override def name = if (_name.eq(null)) s"accumulator_$id" else _name
-  override def display = _display
-  def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, null, true)
+class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String],
+    display: Boolean) extends Accumulable[T,T](initialValue, param, name, display) {
+  def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None, true)
 }
 
 /**
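
The new `display` flag and the `prettyValue` / `prettyPartialValue` hooks let an accumulator control how it is rendered in the web UI. Below is a minimal sketch of a user-defined accumulator built on this API; it is not part of the commit, `BytesAccumulator` is hypothetical, and `SparkContext.LongAccumulatorParam` is assumed to be the existing implicit `AccumulatorParam[Long]` from the SparkContext companion object.

import org.apache.spark.{Accumulator, SparkContext}

// Hypothetical subclass, shown only to illustrate the new constructor arguments
// (name: Option[String], display: Boolean) and the overridable prettyValue hook.
class BytesAccumulator(init: Long)
  extends Accumulator[Long](init, SparkContext.LongAccumulatorParam, Some("bytes read"), true) {
  // Render the running total in a human-friendly form on the stage page.
  override def prettyValue(v: Long): String = s"${v / 1024} KB"
}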

core/src/main/scala/org/apache/spark/SparkContext.scala
Lines changed: 4 additions & 5 deletions

@@ -758,13 +758,12 @@ class SparkContext(config: SparkConf) extends Logging {
     new Accumulator(initialValue, param)
 
   /**
-   * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add"
-   * values to using the `+=` method. Only the driver can access the accumulator's `value`.
-   *
-   * This version adds a custom name to the accumulator for display in the Spark UI.
+   * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display
+   * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the
+   * driver can access the accumulator's `value`.
    */
   def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = {
-    new Accumulator(initialValue, param, name, true)
+    new Accumulator(initialValue, param, Some(name), true)
   }
 
   /**
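
Call sites are unchanged apart from the optional name: a named accumulator shows up under that name on the stage page, while the no-name overload falls back to the generated `accumulator_<id>` label. A spark-shell style sketch (mirroring the programming-guide example further down; the names are illustrative):

val recordsProcessed = sc.accumulator(0, "records processed")  // shown by name in the UI
val anonymous = sc.accumulator(0)                              // shown as accumulator_<id>

sc.parallelize(1 to 1000).foreach(_ => recordsProcessed += 1)
println(recordsProcessed.value)  // only the driver may read the final value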

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Lines changed: 6 additions & 5 deletions

@@ -818,14 +818,15 @@ class DAGScheduler(
         // TODO: fail the stage if the accumulator update fails...
         Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
         event.accumUpdates.foreach { case (id, partialValue) =>
-          val acc = Accumulators.originals(id)
+          val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]]
           val name = acc.name
           // To avoid UI cruft, ignore cases where value wasn't updated
           if (partialValue != acc.zero) {
-            val stringPartialValue = s"${partialValue}"
-            val stringValue = s"${acc.value}"
-            stageToInfos(stage).accumulatedValues(name) = stringValue
-            event.taskInfo.accumulableValues += ((name, stringPartialValue))
+            val stringPartialValue = acc.prettyPartialValue(partialValue)
+            val stringValue = acc.prettyValue(acc.value)
+            stageToInfos(stage).accumulables(id) = AccumulableInfo(id, acc.name, stringValue)
+            event.taskInfo.accumulables +=
+              AccumulableInfo(id, name, Some(stringPartialValue), stringValue)
           }
         }
       }
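
`AccumulableInfo` itself is not part of this diff. Judging from the call sites here and in `JsonProtocol`, it bundles the accumulator id, its name, an optional partial (per-task) update, and the current total; the following is a shape inferred from usage, not the actual definition:

// Inferred shape only; the real definition lives outside this diff.
case class AccumulableInfo(id: Long, name: String, update: Option[String], value: String)

object AccumulableInfo {
  // Convenience overload matching the three-argument call above (no partial update).
  def apply(id: Long, name: String, value: String): AccumulableInfo =
    AccumulableInfo(id, name, None, value)
}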

core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
Lines changed: 2 additions & 3 deletions

@@ -18,7 +18,6 @@
 package org.apache.spark.scheduler
 
 import scala.collection.mutable.HashMap
-import scala.collection.mutable.Map
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.storage.RDDInfo

@@ -40,8 +39,8 @@ class StageInfo(
   var completionTime: Option[Long] = None
   /** If the stage failed, the reason why. */
   var failureReason: Option[String] = None
-  /** Terminal values of accumulables updated during this stage. */
-  val accumulatedValues: Map[String, String] = HashMap[String, String]()
+  /** Terminal values of accumulables updated during this stage.*/
+  val accumulables = HashMap[Long, AccumulableInfo]()
 
   def stageFailed(reason: String) {
     failureReason = Some(reason)

core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
Lines changed: 3 additions & 2 deletions

@@ -45,9 +45,10 @@ class TaskInfo(
 
   /**
    * Intermediate updates to accumulables during this task. Note that it is valid for the same
-   * accumulable to be updated multiple times in a single task.
+   * accumulable to be updated multiple times in a single task or for two accumulables with the
+   * same name but different ID's to exist in a task.
    */
-  val accumulableValues = ListBuffer[(String, String)]()
+  val accumulables = ListBuffer[AccumulableInfo]()
 
   /**
    * The time when the task has completed successfully (including the time to remotely fetch

core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
Lines changed: 9 additions & 7 deletions

@@ -48,7 +48,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener {
 
   // TODO: Should probably consolidate all following into a single hash map.
   val stageIdToTime = HashMap[Int, Long]()
-  val stageIdToAccumulables = HashMap[Int, Map[String, String]]()
+  val stageIdToAccumulables = HashMap[Int, Map[Long, AccumulableInfo]]()
   val stageIdToInputBytes = HashMap[Int, Long]()
   val stageIdToShuffleRead = HashMap[Int, Long]()
   val stageIdToShuffleWrite = HashMap[Int, Long]()

@@ -75,9 +75,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener {
     // Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage
     poolToActiveStages(stageIdToPool(stageId)).remove(stageId)
 
-    val accumulables = stageIdToAccumulables.getOrElseUpdate(stageId, HashMap[String, String]())
-    for ((name, value) <- stageCompleted.stageInfo.accumulatedValues) {
-      accumulables(name) = value
+    val emptyMap = HashMap[Long, AccumulableInfo]()
+    val accumulables = stageIdToAccumulables.getOrElseUpdate(stageId, emptyMap)
+    for ((id, info) <- stageCompleted.stageInfo.accumulables) {
+      accumulables(id) = info
     }
 
     activeStages.remove(stageId)

@@ -155,9 +156,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener {
       val info = taskEnd.taskInfo
 
       if (info != null) {
-        val accumulables = stageIdToAccumulables.getOrElseUpdate(sid, HashMap[String, String]())
-        for ((name, value) <- info.accumulableValues) {
-          accumulables(name) = value
+        val emptyMap = HashMap[Long, AccumulableInfo]()
+        val accumulables = stageIdToAccumulables.getOrElseUpdate(sid, emptyMap)
+        for (accumulableInfo <- info.accumulables) {
+          accumulables(accumulableInfo.id) = accumulableInfo
         }
 
         // create executor summary map if necessary

core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
Lines changed: 7 additions & 4 deletions

@@ -20,10 +20,11 @@ package org.apache.spark.ui.jobs
 import java.util.Date
 import javax.servlet.http.HttpServletRequest
 
-import scala.xml.{Unparsed, Node}
+import scala.xml.{Node, Unparsed}
 
 import org.apache.spark.ui.{WebUIPage, UIUtils}
 import org.apache.spark.util.{Utils, Distribution}
+import org.apache.spark.scheduler.AccumulableInfo
 
 /** Page showing statistics and task list for a given stage */
 private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {

@@ -104,9 +105,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
       </div>
     // scalastyle:on
     val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value")
-    def accumulableRow(acc: (String, String)) = <tr><td>{acc._1}</td><td>{acc._2}</td></tr>
+    def accumulableRow(acc: AccumulableInfo) = <tr><td>{acc.name}</td><td>{acc.value}</td></tr>
     val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow,
-      accumulables.toSeq)
+      accumulables.values.toSeq)
 
     val taskHeaders: Seq[String] =
       Seq(

@@ -291,7 +292,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
           {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""}
         </td>
         <td>
-          {Unparsed(info.accumulableValues.map{ case (k, v) => s"$k: $v" }.mkString("<br/>"))}
+          {Unparsed(
+            info.accumulables.map{acc => s"${acc.name}: ${acc.update.get}"}.mkString("<br/>")
+          )}
         </td>
         <!--
         TODO: Add this back after we add support to hide certain columns.

core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
Lines changed: 31 additions & 18 deletions

@@ -191,13 +191,12 @@ private[spark] object JsonProtocol {
     ("Submission Time" -> submissionTime) ~
     ("Completion Time" -> completionTime) ~
     ("Failure Reason" -> failureReason) ~
-    ("Accumulated Values" -> mapToJson(stageInfo.accumulatedValues))
+    ("Accumulables" -> JArray(
+      stageInfo.accumulables.values.map(accumulableInfoToJson).toList))
   }
 
   def taskInfoToJson(taskInfo: TaskInfo): JValue = {
-    val accumUpdateMap = taskInfo.accumulableValues.map { case (k, v) =>
-      mapToJson(Map(k -> v))
-    }.toList
+    val accumUpdateMap = taskInfo.accumulables
     ("Task ID" -> taskInfo.taskId) ~
     ("Index" -> taskInfo.index) ~
     ("Attempt" -> taskInfo.attempt) ~

@@ -209,7 +208,14 @@ private[spark] object JsonProtocol {
     ("Getting Result Time" -> taskInfo.gettingResultTime) ~
     ("Finish Time" -> taskInfo.finishTime) ~
     ("Failed" -> taskInfo.failed) ~
-    ("Accumulable Updates" -> JArray(accumUpdateMap))
+    ("Accumulables" -> JArray(taskInfo.accumulables.map(accumulableInfoToJson).toList))
+  }
+
+  def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = {
+    ("ID" -> accumulableInfo.id) ~
+    ("Name" -> accumulableInfo.name) ~
+    ("Update" -> accumulableInfo.update.map(new JString(_)).getOrElse(JNothing)) ~
+    ("Value" -> accumulableInfo.value)
   }
 
   def taskMetricsToJson(taskMetrics: TaskMetrics): JValue = {

@@ -485,21 +491,22 @@ private[spark] object JsonProtocol {
     val stageId = (json \ "Stage ID").extract[Int]
     val stageName = (json \ "Stage Name").extract[String]
     val numTasks = (json \ "Number of Tasks").extract[Int]
-    val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson)
+    val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson(_))
     val details = (json \ "Details").extractOpt[String].getOrElse("")
     val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long])
     val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long])
     val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String])
-    val accumulatedValues = (json \ "Accumulated Values").extractOpt[JObject].map(mapFromJson(_))
+    val accumulatedValues = (json \ "Accumulables").extractOpt[List[JValue]] match {
+      case Some(values) => values.map(accumulableInfoFromJson(_))
+      case None => Seq[AccumulableInfo]()
+    }
 
     val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfos, details)
     stageInfo.submissionTime = submissionTime
     stageInfo.completionTime = completionTime
     stageInfo.failureReason = failureReason
-    accumulatedValues.foreach { values =>
-      for ((k, v) <- values) {
-        stageInfo.accumulatedValues(k) = v
-      }
+    for (accInfo <- accumulatedValues) {
+      stageInfo.accumulables(accInfo.id) = accInfo
     }
     stageInfo
   }

@@ -516,22 +523,28 @@ private[spark] object JsonProtocol {
     val gettingResultTime = (json \ "Getting Result Time").extract[Long]
     val finishTime = (json \ "Finish Time").extract[Long]
     val failed = (json \ "Failed").extract[Boolean]
-    val accumulableUpdates = (json \ "Accumulable Updates").extractOpt[Seq[JValue]].map(
-      updates => updates.map(mapFromJson(_)))
+    val accumulables = (json \ "Accumulables").extractOpt[Seq[JValue]] match {
+      case Some(values) => values.map(accumulableInfoFromJson(_))
+      case None => Seq[AccumulableInfo]()
+    }
 
     val taskInfo =
       new TaskInfo(taskId, index, attempt, launchTime, executorId, host, taskLocality, speculative)
     taskInfo.gettingResultTime = gettingResultTime
     taskInfo.finishTime = finishTime
     taskInfo.failed = failed
-    accumulableUpdates.foreach { maps =>
-      for (m <- maps) {
-        taskInfo.accumulableValues += m.head
-      }
-    }
+    accumulables.foreach { taskInfo.accumulables += _ }
     taskInfo
   }
 
+  def accumulableInfoFromJson(json: JValue): AccumulableInfo = {
+    val id = (json \ "id").extract[Long]
+    val name = (json \ "name").extract[String]
+    val update = Utils.jsonOption(json \ "update").map(_.extract[String])
+    val value = (json \ "value").extract[String]
+    AccumulableInfo(id, name, update, value)
+  }
+
   def taskMetricsFromJson(json: JValue): TaskMetrics = {
     if (json == JNothing) {
       return TaskMetrics.empty
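
For reference, a single entry produced by `accumulableInfoToJson` would look roughly like the following (illustrative values; field names taken from the writer side of the diff, with the "Update" field mapped to JNothing when there is no partial update):

{"ID": 2, "Name": "records processed", "Update": "150", "Value": "1000"}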

core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
Lines changed: 9 additions & 9 deletions

@@ -129,7 +129,7 @@ class JsonProtocolSuite extends FunSuite {
 
     // Fields added after 1.0.0.
     assert(info.details.nonEmpty)
-    assert(info.accumulatedValues.nonEmpty)
+    assert(info.accumulables.nonEmpty)
     val oldJson = newJson
       .removeField { case (field, _) => field == "Details" }
       .removeField { case (field, _) => field == "Accumulated Values" }

@@ -138,7 +138,7 @@ class JsonProtocolSuite extends FunSuite {
 
     assert(info.name === newInfo.name)
     assert("" === newInfo.details)
-    assert(0 === newInfo.accumulatedValues.size)
+    assert(0 === newInfo.accumulables.size)
   }
 
   test("InputMetrics backward compatibility") {

@@ -268,7 +268,7 @@ class JsonProtocolSuite extends FunSuite {
     (0 until info1.rddInfos.size).foreach { i =>
       assertEquals(info1.rddInfos(i), info2.rddInfos(i))
     }
-    assert(info1.accumulatedValues === info2.accumulatedValues)
+    assert(info1.accumulables === info2.accumulables)
     assert(info1.details === info2.details)
   }
 

@@ -301,7 +301,7 @@ class JsonProtocolSuite extends FunSuite {
     assert(info1.gettingResultTime === info2.gettingResultTime)
     assert(info1.finishTime === info2.finishTime)
     assert(info1.failed === info2.failed)
-    assert(info1.accumulableValues === info2.accumulableValues)
+    assert(info1.accumulables === info2.accumulables)
   }
 
   private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) {

@@ -487,17 +487,17 @@ class JsonProtocolSuite extends FunSuite {
   private def makeStageInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = {
     val rddInfos = (0 until a % 5).map { i => makeRddInfo(a + i, b + i, c + i, d + i, e + i) }
     val stageInfo = new StageInfo(a, "greetings", b, rddInfos, "details")
-    stageInfo.accumulatedValues("acc1") = "val1"
-    stageInfo.accumulatedValues("acc2") = "val2"
+    stageInfo.accumulables("acc1") = "val1"
+    stageInfo.accumulables("acc2") = "val2"
     stageInfo
   }
 
   private def makeTaskInfo(a: Long, b: Int, c: Int, d: Long, speculative: Boolean) = {
     val taskInfo = new TaskInfo(a, b, c, d, "executor", "your kind sir", TaskLocality.NODE_LOCAL,
       speculative)
-    taskInfo.accumulableValues += (("acc1", "val1"))
-    taskInfo.accumulableValues += (("acc1", "val1"))
-    taskInfo.accumulableValues += (("acc2", "val2"))
+    taskInfo.accumulables += (("acc1", "val1"))
+    taskInfo.accumulables += (("acc1", "val1"))
+    taskInfo.accumulables += (("acc2", "val2"))
     taskInfo
   }
 

docs/programming-guide.md
Lines changed: 3 additions & 2 deletions

@@ -1180,7 +1180,8 @@ value of the broadcast variable (e.g. if the variable is shipped to a new node l
 Accumulators are variables that are only "added" to through an associative operation and can
 therefore be efficiently supported in parallel. They can be used to implement counters (as in
 MapReduce) or sums. Spark natively supports accumulators of numeric types, and programmers
-can add support for new types.
+can add support for new types. Accumulator values are displayed in Spark's UI and can be
+useful for understanding the progress of running stages.
 
 An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks
 running on the cluster can then add to it using the `add` method or the `+=` operator (in Scala and Python).

@@ -1194,7 +1195,7 @@ The code below shows an accumulator being used to add up the elements of an arra
 <div data-lang="scala" markdown="1">
 
 {% highlight scala %}
-scala> val accum = sc.accumulator(0)
+scala> val accum = sc.accumulator(0, "My Accumulator")
 accum: spark.Accumulator[Int] = 0
 
 scala> sc.parallelize(Array(1, 2, 3, 4)).foreach(x => accum += x)
