Skip to content

Commit d0edef3

Browse files
committed
Add framework for broadcast cleanup
As of this commit, Spark does not yet clean up broadcast blocks; only the framework is added here. The actual cleanup will be done in the next commit.
1 parent ba52e00 commit d0edef3

12 files changed

+249
-112
lines changed

core/src/main/scala/org/apache/spark/ContextCleaner.scala

Lines changed: 86 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -21,105 +21,106 @@ import java.lang.ref.{ReferenceQueue, WeakReference}
2121

2222
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
2323

24+
import org.apache.spark.broadcast.Broadcast
2425
import org.apache.spark.rdd.RDD
2526

26-
/** Listener class used for testing when any item has been cleaned by the Cleaner class */
27-
private[spark] trait CleanerListener {
28-
def rddCleaned(rddId: Int)
29-
def shuffleCleaned(shuffleId: Int)
30-
}
27+
/**
28+
* Classes that represent cleaning tasks.
29+
*/
30+
private sealed trait CleanupTask
31+
private case class CleanRDD(rddId: Int) extends CleanupTask
32+
private case class CleanShuffle(shuffleId: Int) extends CleanupTask
33+
private case class CleanBroadcast(broadcastId: Long) extends CleanupTask
3134

3235
/**
33-
* Cleans RDDs and shuffle data.
36+
* A WeakReference associated with a CleanupTask.
37+
*
38+
* When the referent object becomes only weakly reachable, the corresponding
39+
* CleanupTaskWeakReference is automatically added to the given reference queue.
40+
*/
41+
private class CleanupTaskWeakReference(
42+
val task: CleanupTask,
43+
referent: AnyRef,
44+
referenceQueue: ReferenceQueue[AnyRef])
45+
extends WeakReference(referent, referenceQueue)
46+
47+
/**
48+
* An asynchronous cleaner for RDD, shuffle, and broadcast state.
49+
*
50+
* This maintains a weak reference for each RDD, ShuffleDependency, and Broadcast of interest,
51+
* to be processed when the associated object goes out of scope of the application. Actual
52+
* cleanup is performed in a separate daemon thread.
3453
*/
3554
private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
3655

37-
/** Classes to represent cleaning tasks */
38-
private sealed trait CleanupTask
39-
private case class CleanRDD(rddId: Int) extends CleanupTask
40-
private case class CleanShuffle(shuffleId: Int) extends CleanupTask
41-
// TODO: add CleanBroadcast
56+
private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference]
57+
with SynchronizedBuffer[CleanupTaskWeakReference]
4258

43-
private val referenceBuffer = new ArrayBuffer[WeakReferenceWithCleanupTask]
44-
with SynchronizedBuffer[WeakReferenceWithCleanupTask]
4559
private val referenceQueue = new ReferenceQueue[AnyRef]
4660

4761
private val listeners = new ArrayBuffer[CleanerListener]
4862
with SynchronizedBuffer[CleanerListener]
4963

5064
private val cleaningThread = new Thread() { override def run() { keepCleaning() }}
5165

52-
private val REF_QUEUE_POLL_TIMEOUT = 100
53-
5466
@volatile private var stopped = false
5567

56-
private class WeakReferenceWithCleanupTask(referent: AnyRef, val task: CleanupTask)
57-
extends WeakReference(referent, referenceQueue)
68+
/** Attach a listener object to be notified when objects are cleaned. */
69+
def attachListener(listener: CleanerListener) {
70+
listeners += listener
71+
}
5872

59-
/** Start the cleaner */
73+
/** Start the cleaner. */
6074
def start() {
6175
cleaningThread.setDaemon(true)
6276
cleaningThread.setName("ContextCleaner")
6377
cleaningThread.start()
6478
}
6579

66-
/** Stop the cleaner */
80+
/** Stop the cleaner. */
6781
def stop() {
6882
stopped = true
6983
cleaningThread.interrupt()
7084
}
7185

72-
/**
73-
* Register a RDD for cleanup when it is garbage collected.
74-
*/
86+
/** Register an RDD for cleanup when it is garbage collected. */
7587
def registerRDDForCleanup(rdd: RDD[_]) {
7688
registerForCleanup(rdd, CleanRDD(rdd.id))
7789
}
7890

79-
/**
80-
* Register a shuffle dependency for cleanup when it is garbage collected.
81-
*/
91+
/** Register a ShuffleDependency for cleanup when it is garbage collected. */
8292
def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
8393
registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
8494
}
8595

86-
/** Cleanup RDD. */
87-
def cleanupRDD(rdd: RDD[_]) {
88-
doCleanupRDD(rdd.id)
89-
}
90-
91-
/** Cleanup shuffle. */
92-
def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
93-
doCleanupShuffle(shuffleDependency.shuffleId)
94-
}
95-
96-
/** Attach a listener object to get information of when objects are cleaned. */
97-
def attachListener(listener: CleanerListener) {
98-
listeners += listener
96+
/** Register a Broadcast for cleanup when it is garbage collected. */
97+
def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) {
98+
registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
9999
}
100100

101101
/** Register an object for cleanup. */
102102
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
103-
referenceBuffer += new WeakReferenceWithCleanupTask(objectForCleanup, task)
103+
referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
104104
}
105105

106-
/** Keep cleaning RDDs and shuffle data */
106+
/** Keep cleaning RDD, shuffle, and broadcast state. */
107107
private def keepCleaning() {
108-
while (!isStopped) {
108+
while (!stopped) {
109109
try {
110-
val reference = Option(referenceQueue.remove(REF_QUEUE_POLL_TIMEOUT))
111-
.map(_.asInstanceOf[WeakReferenceWithCleanupTask])
110+
val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT))
111+
.map(_.asInstanceOf[CleanupTaskWeakReference])
112112
reference.map(_.task).foreach { task =>
113113
logDebug("Got cleaning task " + task)
114114
referenceBuffer -= reference.get
115115
task match {
116116
case CleanRDD(rddId) => doCleanupRDD(rddId)
117117
case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId)
118+
case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId)
118119
}
119120
}
120121
} catch {
121122
case ie: InterruptedException =>
122-
if (!isStopped) logWarning("Cleaning thread interrupted")
123+
if (!stopped) logWarning("Cleaning thread interrupted")
123124
case t: Throwable => logError("Error in cleaning thread", t)
124125
}
125126
}
@@ -129,7 +130,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
129130
private def doCleanupRDD(rddId: Int) {
130131
try {
131132
logDebug("Cleaning RDD " + rddId)
132-
sc.unpersistRDD(rddId, false)
133+
sc.unpersistRDD(rddId, blocking = false)
133134
listeners.foreach(_.rddCleaned(rddId))
134135
logInfo("Cleaned RDD " + rddId)
135136
} catch {
@@ -150,10 +151,47 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
150151
}
151152
}
152153

153-
private def mapOutputTrackerMaster =
154-
sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
154+
/** Perform broadcast cleanup. */
155+
private def doCleanupBroadcast(broadcastId: Long) {
156+
try {
157+
logDebug("Cleaning broadcast " + broadcastId)
158+
broadcastManager.unbroadcast(broadcastId, removeFromDriver = true)
159+
listeners.foreach(_.broadcastCleaned(broadcastId))
160+
logInfo("Cleaned broadcast " + broadcastId)
161+
} catch {
162+
case t: Throwable => logError("Error cleaning broadcast " + broadcastId, t)
163+
}
164+
}
155165

156166
private def blockManagerMaster = sc.env.blockManager.master
167+
private def broadcastManager = sc.env.broadcastManager
168+
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
169+
170+
// Used for testing
171+
172+
private[spark] def cleanupRDD(rdd: RDD[_]) {
173+
doCleanupRDD(rdd.id)
174+
}
175+
176+
private[spark] def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
177+
doCleanupShuffle(shuffleDependency.shuffleId)
178+
}
157179

158-
private def isStopped = stopped
180+
private[spark] def cleanupBroadcast[T](broadcast: Broadcast[T]) {
181+
doCleanupBroadcast(broadcast.id)
182+
}
183+
184+
}
185+
186+
private object ContextCleaner {
187+
private val REF_QUEUE_POLL_TIMEOUT = 100
188+
}
189+
190+
/**
191+
* Listener trait, used for testing, that is notified whenever the ContextCleaner has cleaned an item.
192+
*/
193+
private[spark] trait CleanerListener {
194+
def rddCleaned(rddId: Int)
195+
def shuffleCleaned(shuffleId: Int)
196+
def broadcastCleaned(broadcastId: Long)
159197
}

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,11 @@ class SparkContext(
642642
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
643643
* The variable will be sent to each cluster only once.
644644
*/
645-
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
645+
def broadcast[T](value: T) = {
646+
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
647+
cleaner.registerBroadcastForCleanup(bc)
648+
bc
649+
}
646650

647651
/**
648652
* Add a file to be downloaded with this Spark job on every node.

core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ import java.io.Serializable
5050
abstract class Broadcast[T](val id: Long) extends Serializable {
5151
def value: T
5252

53+
/**
54+
* Remove all persisted state associated with this broadcast.
55+
* @param removeFromDriver Whether to remove state from the driver.
56+
*/
57+
def unpersist(removeFromDriver: Boolean)
58+
5359
// We cannot have an abstract readObject here due to some weird issues with
5460
// readObject having to be 'private' in sub-classes.
5561

core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,6 @@ import org.apache.spark.SparkConf
2929
trait BroadcastFactory {
3030
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
3131
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
32+
def unbroadcast(id: Long, removeFromDriver: Boolean)
3233
def stop(): Unit
3334
}

core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,8 @@ private[spark] class BroadcastManager(
6060
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
6161
}
6262

63+
def unbroadcast(id: Long, removeFromDriver: Boolean) {
64+
broadcastFactory.unbroadcast(id, removeFromDriver)
65+
}
66+
6367
}

0 commit comments

Comments
 (0)