
Commit 762a4d8

Merge pull request #1 from andrewor14/cleanup
I am merging this. I will take one more detailed look in the context of my original changes in the main PR.
2 parents: 7edbc98 + f0aabb1

38 files changed: +1460 -928 lines

core/src/main/scala/org/apache/spark/ContextCleaner.scala

Lines changed: 85 additions & 48 deletions

@@ -21,105 +21,106 @@ import java.lang.ref.{ReferenceQueue, WeakReference}
 
 import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
 
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 
-/** Listener class used for testing when any item has been cleaned by the Cleaner class */
-private[spark] trait CleanerListener {
-  def rddCleaned(rddId: Int)
-  def shuffleCleaned(shuffleId: Int)
-}
+/**
+ * Classes that represent cleaning tasks.
+ */
+private sealed trait CleanupTask
+private case class CleanRDD(rddId: Int) extends CleanupTask
+private case class CleanShuffle(shuffleId: Int) extends CleanupTask
+private case class CleanBroadcast(broadcastId: Long) extends CleanupTask
 
 /**
- * Cleans RDDs and shuffle data.
+ * A WeakReference associated with a CleanupTask.
+ *
+ * When the referent object becomes only weakly reachable, the corresponding
+ * CleanupTaskWeakReference is automatically added to the given reference queue.
+ */
+private class CleanupTaskWeakReference(
+    val task: CleanupTask,
+    referent: AnyRef,
+    referenceQueue: ReferenceQueue[AnyRef])
+  extends WeakReference(referent, referenceQueue)
+
+/**
+ * An asynchronous cleaner for RDD, shuffle, and broadcast state.
+ *
+ * This maintains a weak reference for each RDD, ShuffleDependency, and Broadcast of interest,
+ * to be processed when the associated object goes out of scope of the application. Actual
+ * cleanup is performed in a separate daemon thread.
  */
 private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
 
-  /** Classes to represent cleaning tasks */
-  private sealed trait CleanupTask
-  private case class CleanRDD(rddId: Int) extends CleanupTask
-  private case class CleanShuffle(shuffleId: Int) extends CleanupTask
-  // TODO: add CleanBroadcast
+  private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference]
+    with SynchronizedBuffer[CleanupTaskWeakReference]
 
-  private val referenceBuffer = new ArrayBuffer[WeakReferenceWithCleanupTask]
-    with SynchronizedBuffer[WeakReferenceWithCleanupTask]
   private val referenceQueue = new ReferenceQueue[AnyRef]
 
   private val listeners = new ArrayBuffer[CleanerListener]
     with SynchronizedBuffer[CleanerListener]
 
   private val cleaningThread = new Thread() { override def run() { keepCleaning() }}
 
-  private val REF_QUEUE_POLL_TIMEOUT = 100
-
   @volatile private var stopped = false
 
-  private class WeakReferenceWithCleanupTask(referent: AnyRef, val task: CleanupTask)
-    extends WeakReference(referent, referenceQueue)
+  /** Attach a listener object to get information of when objects are cleaned. */
+  def attachListener(listener: CleanerListener) {
+    listeners += listener
+  }
 
-  /** Start the cleaner */
+  /** Start the cleaner. */
   def start() {
     cleaningThread.setDaemon(true)
     cleaningThread.setName("ContextCleaner")
     cleaningThread.start()
   }
 
-  /** Stop the cleaner */
+  /** Stop the cleaner. */
   def stop() {
     stopped = true
     cleaningThread.interrupt()
   }
 
-  /**
-   * Register a RDD for cleanup when it is garbage collected.
-   */
+  /** Register a RDD for cleanup when it is garbage collected. */
   def registerRDDForCleanup(rdd: RDD[_]) {
     registerForCleanup(rdd, CleanRDD(rdd.id))
   }
 
-  /**
-   * Register a shuffle dependency for cleanup when it is garbage collected.
-   */
+  /** Register a ShuffleDependency for cleanup when it is garbage collected. */
   def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
     registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
   }
 
-  /** Cleanup RDD. */
-  def cleanupRDD(rdd: RDD[_]) {
-    doCleanupRDD(rdd.id)
-  }
-
-  /** Cleanup shuffle. */
-  def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
-    doCleanupShuffle(shuffleDependency.shuffleId)
-  }
-
-  /** Attach a listener object to get information of when objects are cleaned. */
-  def attachListener(listener: CleanerListener) {
-    listeners += listener
+  /** Register a Broadcast for cleanup when it is garbage collected. */
+  def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) {
+    registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
   }
 
   /** Register an object for cleanup. */
   private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
-    referenceBuffer += new WeakReferenceWithCleanupTask(objectForCleanup, task)
+    referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
   }
 
-  /** Keep cleaning RDDs and shuffle data */
+  /** Keep cleaning RDD, shuffle, and broadcast state. */
   private def keepCleaning() {
-    while (!isStopped) {
+    while (!stopped) {
      try {
-        val reference = Option(referenceQueue.remove(REF_QUEUE_POLL_TIMEOUT))
-          .map(_.asInstanceOf[WeakReferenceWithCleanupTask])
+        val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT))
+          .map(_.asInstanceOf[CleanupTaskWeakReference])
        reference.map(_.task).foreach { task =>
          logDebug("Got cleaning task " + task)
          referenceBuffer -= reference.get
          task match {
            case CleanRDD(rddId) => doCleanupRDD(rddId)
            case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId)
+            case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId)
          }
        }
      } catch {
        case ie: InterruptedException =>
-          if (!isStopped) logWarning("Cleaning thread interrupted")
+          if (!stopped) logWarning("Cleaning thread interrupted")
        case t: Throwable => logError("Error in cleaning thread", t)
      }
    }

@@ -129,7 +130,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   private def doCleanupRDD(rddId: Int) {
     try {
       logDebug("Cleaning RDD " + rddId)
-      sc.unpersistRDD(rddId, false)
+      sc.unpersistRDD(rddId, blocking = false)
       listeners.foreach(_.rddCleaned(rddId))
       logInfo("Cleaned RDD " + rddId)
     } catch {

@@ -150,10 +151,46 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
     }
   }
 
-  private def mapOutputTrackerMaster =
-    sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+  /** Perform broadcast cleanup. */
+  private def doCleanupBroadcast(broadcastId: Long) {
+    try {
+      logDebug("Cleaning broadcast " + broadcastId)
+      broadcastManager.unbroadcast(broadcastId, removeFromDriver = true)
+      listeners.foreach(_.broadcastCleaned(broadcastId))
+      logInfo("Cleaned broadcast " + broadcastId)
+    } catch {
+      case t: Throwable => logError("Error cleaning broadcast " + broadcastId, t)
+    }
+  }
 
   private def blockManagerMaster = sc.env.blockManager.master
+  private def broadcastManager = sc.env.broadcastManager
+  private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+
+  // Used for testing
+
+  def cleanupRDD(rdd: RDD[_]) {
+    doCleanupRDD(rdd.id)
+  }
+
+  def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
+    doCleanupShuffle(shuffleDependency.shuffleId)
+  }
 
-  private def isStopped = stopped
+  def cleanupBroadcast[T](broadcast: Broadcast[T]) {
+    doCleanupBroadcast(broadcast.id)
+  }
+}
+
+private object ContextCleaner {
+  private val REF_QUEUE_POLL_TIMEOUT = 100
+}
+
+/**
+ * Listener class used for testing when any item has been cleaned by the Cleaner class.
+ */
+private[spark] trait CleanerListener {
+  def rddCleaned(rddId: Int)
+  def shuffleCleaned(shuffleId: Int)
+  def broadcastCleaned(broadcastId: Long)
 }

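The cleaner's core trick is the JVM weak-reference/reference-queue pattern: each registered object is wrapped in a CleanupTaskWeakReference, and once the object becomes only weakly reachable the JVM enqueues that reference, so keepCleaning() can recover the associated task without keeping the object alive. A minimal standalone Scala sketch of that pattern (illustrative names only, not code from this commit):

import java.lang.ref.{ReferenceQueue, WeakReference}

// Illustrative stand-ins for the commit's CleanupTask hierarchy.
sealed trait Task
case class CleanThing(id: Int) extends Task

// A weak reference that also carries the cleanup task for its referent,
// mirroring CleanupTaskWeakReference in the diff above.
class TaskRef(val task: Task, referent: AnyRef, queue: ReferenceQueue[AnyRef])
  extends WeakReference[AnyRef](referent, queue)

object WeakRefDemo {
  def main(args: Array[String]): Unit = {
    val queue = new ReferenceQueue[AnyRef]
    var payload: AnyRef = new Array[Byte](1 << 20)
    val ref = new TaskRef(CleanThing(42), payload, queue)

    payload = null  // drop the only strong reference to the payload
    System.gc()     // only a hint; collection is not guaranteed to be prompt

    // Poll with a timeout, as keepCleaning() does. If the payload was collected,
    // the enqueued reference is the same TaskRef we created, carrying its task.
    Option(queue.remove(1000)).map(_.asInstanceOf[TaskRef]).foreach { r =>
      println("Ready to clean: " + r.task + " (same ref: " + (r eq ref) + ")")
    }
  }
}

Because System.gc() is only a hint, the cleaner cannot rely on prompt enqueueing; that is why keepCleaning() polls the queue with a timeout in a loop rather than blocking indefinitely.
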
core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 5 additions & 6 deletions

@@ -112,8 +112,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
   }
 
   /**
-   * Called from executors to get the server URIs and
-   * output sizes of the map outputs of a given shuffle
+   * Called from executors to get the server URIs and output sizes of the map outputs of
+   * a given shuffle.
    */
   def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
     val statuses = mapStatuses.get(shuffleId).orNull

@@ -218,10 +218,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
   private var cacheEpoch = epoch
 
   /**
-   * Timestamp based HashMap for storing mapStatuses and cached serialized statuses
-   * in the master, so that statuses are dropped only by explicit deregistering or
-   * by TTL-based cleaning (if set). Other than these two
-   * scenarios, nothing should be dropped from this HashMap.
+   * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the master,
+   * so that statuses are dropped only by explicit deregistering or by TTL-based cleaning (if set).
+   * Other than these two scenarios, nothing should be dropped from this HashMap.
   */
   protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]()
   private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]()

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 6 additions & 2 deletions

@@ -35,7 +35,6 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad
 import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
 import org.apache.mesos.MesosNativeLibrary
 
-import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
 import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd._

@@ -230,6 +229,7 @@ class SparkContext(
 
   private[spark] val cleaner = new ContextCleaner(this)
   cleaner.start()
+
   postEnvironmentUpdate()
 
   /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */

@@ -643,7 +643,11 @@ class SparkContext(
   * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
   * The variable will be sent to each cluster only once.
   */
-  def broadcast[T](value: T): Broadcast[T] = env.broadcastManager.newBroadcast[T](value, isLocal)
+  def broadcast[T](value: T) = {
+    val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
+    cleaner.registerBroadcastForCleanup(bc)
+    bc
+  }
 
   /**
    * Add a file to be downloaded with this Spark job on every node.

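The new broadcast[T] body registers every broadcast with the cleaner, but the user-facing API is unchanged. A minimal driver-program sketch (assuming a local master; the lookup map and object name are illustrative):

import org.apache.spark.{SparkConf, SparkContext}

object BroadcastUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("broadcast-sketch"))

    // SparkContext.broadcast now also registers the returned variable with the
    // ContextCleaner (cleaner.registerBroadcastForCleanup(bc) in the diff above).
    val lookup = sc.broadcast(Map(1 -> "a", 2 -> "b"))

    val tagged = sc.parallelize(Seq(1, 2, 1)).map(i => lookup.value.getOrElse(i, "?")).collect()
    println(tagged.mkString(","))

    sc.stop()
  }
}

Once the driver drops its last strong reference to lookup, the cleaner's weak reference is all that remains, and the broadcast becomes eligible for asynchronous cleanup.
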
core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 1 addition & 0 deletions

@@ -185,6 +185,7 @@ object SparkEnv extends Logging {
     } else {
       new MapOutputTrackerWorker(conf)
     }
+
     // Have to assign trackerActor after initialization as MapOutputTrackerActor
     // requires the MapOutputTracker itself
     mapOutputTracker.trackerActor = registerOrLookup(

core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala

Lines changed: 27 additions & 40 deletions

@@ -18,9 +18,8 @@
 package org.apache.spark.broadcast
 
 import java.io.Serializable
-import java.util.concurrent.atomic.AtomicLong
 
-import org.apache.spark._
+import org.apache.spark.SparkException
 
 /**
  * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable

@@ -51,49 +50,37 @@ import org.apache.spark._
  * @tparam T Type of the data contained in the broadcast variable.
  */
 abstract class Broadcast[T](val id: Long) extends Serializable {
-  def value: T
-
-  // We cannot have an abstract readObject here due to some weird issues with
-  // readObject having to be 'private' in sub-classes.
-
-  override def toString = "Broadcast(" + id + ")"
-}
-
-private[spark]
-class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager)
-  extends Logging with Serializable {
-
-  private var initialized = false
-  private var broadcastFactory: BroadcastFactory = null
 
-  initialize()
+  protected var _isValid: Boolean = true
 
-  // Called by SparkContext or Executor before using Broadcast
-  private def initialize() {
-    synchronized {
-      if (!initialized) {
-        val broadcastFactoryClass = conf.get(
-          "spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+  /**
+   * Whether this Broadcast is actually usable. This should be false once persisted state is
+   * removed from the driver.
+   */
+  def isValid: Boolean = _isValid
 
-        broadcastFactory =
-          Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
-
-        // Initialize appropriate BroadcastFactory and BroadcastObject
-        broadcastFactory.initialize(isDriver, conf, securityManager)
+  def value: T
 
-        initialized = true
-      }
+  /**
+   * Remove all persisted state associated with this broadcast on the executors. The next use
+   * of this broadcast on the executors will trigger a remote fetch.
+   */
+  def unpersist()
+
+  /**
+   * Remove all persisted state associated with this broadcast on both the executors and the
+   * driver. Overriding implementations should set isValid to false.
+   */
+  private[spark] def destroy()
+
+  /**
+   * If this broadcast is no longer valid, throw an exception.
+   */
+  protected def assertValid() {
+    if (!_isValid) {
+      throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString))
     }
   }
 
-  def stop() {
-    broadcastFactory.stop()
-  }
-
-  private val nextBroadcastId = new AtomicLong(0)
-
-  def newBroadcast[T](value_ : T, isLocal: Boolean) =
-    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
-
-  def isDriver = _isDriver
+  override def toString = "Broadcast(" + id + ")"
 }

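The reworked Broadcast API expects concrete subclasses (HttpBroadcast and other implementations not shown in this section) to guard reads with assertValid() and to flip _isValid when the broadcast is destroyed. A hypothetical minimal subclass, assuming it lives in the org.apache.spark.broadcast package so the private[spark] override is legal:

package org.apache.spark.broadcast

// Hypothetical subclass (not in this commit) illustrating the intended contract
// of the reworked Broadcast API: guard reads with assertValid() and flip
// _isValid when the broadcast is destroyed.
class InMemoryBroadcast[T](value_ : T, id: Long) extends Broadcast[T](id) {

  override def value: T = {
    assertValid()  // throws SparkException once the broadcast has been destroyed
    value_
  }

  override def unpersist() {
    // A real implementation would drop the blocks cached on the executors here.
  }

  override private[spark] def destroy() {
    _isValid = false
    // A real implementation would also remove the driver-side copy here.
  }
}
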
core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala

Lines changed: 3 additions & 2 deletions

@@ -27,7 +27,8 @@ import org.apache.spark.SparkConf
  * entire Spark job.
  */
 trait BroadcastFactory {
-  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
+  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager)
   def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
-  def stop(): Unit
+  def unbroadcast(id: Long, removeFromDriver: Boolean)
+  def stop()
 }

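The added unbroadcast(id, removeFromDriver) hook is what ContextCleaner.doCleanupBroadcast() ultimately reaches through the BroadcastManager. A hypothetical in-memory factory (illustrative only; the default implementation named in the removed BroadcastManager code is HttpBroadcastFactory) sketches where the removeFromDriver flag matters, reusing the hypothetical InMemoryBroadcast from the previous sketch:

package org.apache.spark.broadcast

import scala.collection.mutable

import org.apache.spark.{SecurityManager, SparkConf}

// Hypothetical factory (not in this commit): removeFromDriver = true forgets the
// broadcast entirely, while false would only touch executor-side state.
class DummyBroadcastFactory extends BroadcastFactory {

  private val driverSideValues = new mutable.HashMap[Long, Any]

  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {}

  def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] = {
    driverSideValues(id) = value
    new InMemoryBroadcast[T](value, id)  // the hypothetical subclass sketched above
  }

  def unbroadcast(id: Long, removeFromDriver: Boolean) {
    // Executor-side cleanup would happen here in a real implementation.
    if (removeFromDriver) {
      driverSideValues.remove(id)
    }
  }

  def stop() {}
}
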