Skip to content

Commit 537baad

Browse files
author
Nathan Kronenfeld
committed
Fuller refactoring as intended, incorporating JR's suggestions for ThreadLocal localAccums, and keeping clear(), but also calling it in tasks' finally block, rather than just at the beginning of the task.
1 parent 39a82f2 commit 537baad

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

core/src/main/scala/org/apache/spark/Accumulators.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark
1919

2020
import java.io.{ObjectInputStream, Serializable}
2121
import java.util.concurrent.atomic.AtomicLong
22+
import java.lang.ThreadLocal
2223

2324
import scala.collection.generic.Growable
2425
import scala.collection.mutable.Map
@@ -281,7 +282,9 @@ object AccumulatorParam {
281282
private object Accumulators {
282283
// TODO: Use soft references? => need to make readObject work properly then
283284
val originals = Map[Long, Accumulable[_, _]]()
284-
val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]()
285+
val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
286+
override protected def initialValue() = Map[Long, Accumulable[_, _]]()
287+
}
285288
var lastId: Long = 0
286289

287290
def newId(): Long = synchronized {
@@ -293,25 +296,23 @@ private object Accumulators {
293296
if (original) {
294297
originals(a.id) = a
295298
} else {
296-
val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map())
297-
accums(a.id) = a
299+
localAccums.get()(a.id) = a
298300
}
299301
}
300302

301303
// Clear the local (non-original) accumulators for the current thread
302304
def clear() {
303305
synchronized {
304-
localAccums.remove(Thread.currentThread)
306+
localAccums.get.clear
305307
}
306308
}
307309

308310
// Get the values of the local accumulators for the current thread (by ID)
309311
def values: Map[Long, Any] = synchronized {
310312
val ret = Map[Long, Any]()
311-
for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
313+
for ((id, accum) <- localAccums.get) {
312314
ret(id) = accum.localValue
313315
}
314-
localAccums.remove(Thread.currentThread)
315316
return ret
316317
}
317318

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ private[spark] class Executor(
172172
val startGCTime = gcTime
173173

174174
try {
175-
Accumulators.clear()
176175
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
177176
updateDependencies(taskFiles, taskJars)
178177
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
@@ -278,6 +277,8 @@ private[spark] class Executor(
278277
env.shuffleMemoryManager.releaseMemoryForThisThread()
279278
// Release memory used by this thread for unrolling blocks
280279
env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
280+
// Release memory used by this thread for accumulators
281+
Accumulators.clear()
281282
runningTasks.remove(taskId)
282283
}
283284
}

0 commit comments

Comments
 (0)