
Commit e427a9e

Added ContextCleaner to automatically clean RDDs and shuffles when they fall out of scope. Also replaced TimeStampedHashMap with BoundedHashMap and TimeStampedWeakValueHashMap where those hashmap behaviors are needed.
1 parent 3a9d82c commit e427a9e

22 files changed (+946, -60 lines)
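
Note: the sketch below is not part of the commit; it illustrates the intended behavior under an assumed local setup. Once a driver-side RDD reference goes out of scope, the new RDD.finalize() calls cleanup(), which enqueues the RDD and its shuffle dependencies on the ContextCleaner's queue; the cleaner thread then removes the cached blocks, the shuffle's map output statuses, and the shuffle data.

// Sketch only; not part of this commit. The app/master names are illustrative.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._   // pair-RDD functions such as groupByKey

object ContextCleanerSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("cleaner-sketch"))

    def runShuffleJob() {
      // groupByKey registers a shuffle; cache() registers the RDD in persistentRdds.
      val grouped = sc.parallelize(1 to 1000).map(i => (i % 10, i)).groupByKey().cache()
      grouped.count()
      // `grouped` becomes unreachable when this method returns.
    }

    runShuffleJob()
    // Cleanup is GC-driven (RDD.finalize -> cleanup -> ContextCleaner queue), so its timing
    // is best-effort; forcing a GC only makes prompt cleaning more likely.
    System.gc()
    Thread.sleep(1000)
    sc.stop()
  }
}
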
core/src/main/scala/org/apache/spark/ContextCleaner.scala

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}

import java.util.concurrent.{ArrayBlockingQueue, TimeUnit}

import org.apache.spark.rdd.RDD

/** Listener class used for testing when any item has been cleaned by the Cleaner class */
private[spark] trait CleanerListener {
  def rddCleaned(rddId: Int)
  def shuffleCleaned(shuffleId: Int)
}

/**
 * Cleans RDDs and shuffle data. This should be instantiated only on the driver.
 */
private[spark] class ContextCleaner(env: SparkEnv) extends Logging {

  /** Classes to represent cleaning tasks */
  private sealed trait CleaningTask
  private case class CleanRDD(sc: SparkContext, id: Int) extends CleaningTask
  private case class CleanShuffle(id: Int) extends CleaningTask
  // TODO: add CleanBroadcast

  private val QUEUE_CAPACITY = 1000
  private val queue = new ArrayBlockingQueue[CleaningTask](QUEUE_CAPACITY)

  protected val listeners = new ArrayBuffer[CleanerListener]
    with SynchronizedBuffer[CleanerListener]

  private val cleaningThread = new Thread() { override def run() { keepCleaning() }}

  private var stopped = false

  /** Start the cleaner */
  def start() {
    cleaningThread.setDaemon(true)
    cleaningThread.start()
  }

  /** Stop the cleaner */
  def stop() {
    synchronized { stopped = true }
    cleaningThread.interrupt()
  }

  /** Clean all data and metadata related to a RDD, including shuffle files and metadata */
  def cleanRDD(rdd: RDD[_]) {
    enqueue(CleanRDD(rdd.sparkContext, rdd.id))
    logDebug("Enqueued RDD " + rdd + " for cleaning up")
  }

  def cleanShuffle(shuffleId: Int) {
    enqueue(CleanShuffle(shuffleId))
    logDebug("Enqueued shuffle " + shuffleId + " for cleaning up")
  }

  def attachListener(listener: CleanerListener) {
    listeners += listener
  }

  /** Enqueue a cleaning task */
  private def enqueue(task: CleaningTask) {
    queue.put(task)
  }

  /** Keep cleaning RDDs and shuffle data */
  private def keepCleaning() {
    try {
      while (!isStopped) {
        val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS))
        if (taskOpt.isDefined) {
          logDebug("Got cleaning task " + taskOpt.get)
          taskOpt.get match {
            case CleanRDD(sc, rddId) => doCleanRDD(sc, rddId)
            case CleanShuffle(shuffleId) => doCleanShuffle(shuffleId)
          }
        }
      }
    } catch {
      case ie: java.lang.InterruptedException =>
        if (!isStopped) logWarning("Cleaning thread interrupted")
    }
  }

  /** Perform RDD cleaning */
  private def doCleanRDD(sc: SparkContext, rddId: Int) {
    logDebug("Cleaning rdd " + rddId)
    sc.env.blockManager.master.removeRdd(rddId, false)
    sc.persistentRdds.remove(rddId)
    listeners.foreach(_.rddCleaned(rddId))
    logInfo("Cleaned rdd " + rddId)
  }

  /** Perform shuffle cleaning */
  private def doCleanShuffle(shuffleId: Int) {
    logDebug("Cleaning shuffle " + shuffleId)
    mapOutputTrackerMaster.unregisterShuffle(shuffleId)
    blockManager.master.removeShuffle(shuffleId)
    listeners.foreach(_.shuffleCleaned(shuffleId))
    logInfo("Cleaned shuffle " + shuffleId)
  }

  private def mapOutputTrackerMaster = env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]

  private def blockManager = env.blockManager

  private def isStopped = synchronized { stopped }
}
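
Note: a minimal sketch (not part of this commit) of how a test might use the CleanerListener hook above to observe cleaning. It assumes the snippet lives in the org.apache.spark package (SparkContext.cleaner is private[spark]) and that `sc` is an existing SparkContext.

// Sketch only: observe RDD cleaning through the listener interface defined above.
import java.util.concurrent.{CountDownLatch, TimeUnit}

val cleaned = new CountDownLatch(1)
sc.cleaner.attachListener(new CleanerListener {
  def rddCleaned(rddId: Int) { cleaned.countDown() }
  def shuffleCleaned(shuffleId: Int) { }
})

val rdd = sc.parallelize(1 to 100).cache()
rdd.count()
rdd.cleanup()   // defined on RDD later in this commit; enqueues CleanRDD (and CleanShuffle for shuffle deps)

assert(cleaned.await(10, TimeUnit.SECONDS), "RDD was not cleaned within the timeout")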

core/src/main/scala/org/apache/spark/Dependency.scala

Lines changed: 6 additions & 0 deletions
@@ -52,6 +52,12 @@ class ShuffleDependency[K, V](
   extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
 
   val shuffleId: Int = rdd.context.newShuffleId()
+
+  override def finalize() {
+    if (rdd != null) {
+      rdd.sparkContext.cleaner.cleanShuffle(shuffleId)
+    }
+  }
 }
 
core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 45 additions & 19 deletions
@@ -17,19 +17,19 @@
 
 package org.apache.spark
 
+import scala.Some
+import scala.collection.mutable.{HashSet, Map}
+import scala.concurrent.Await
+
 import java.io._
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
-import scala.collection.mutable.HashSet
-import scala.concurrent.Await
-import scala.concurrent.duration._
-
 import akka.actor._
 import akka.pattern.ask
 
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.util._
 
 private[spark] sealed trait MapOutputTrackerMessage
 private[spark] case class GetMapOutputStatuses(shuffleId: Int)
@@ -51,23 +51,21 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
   }
 }
 
-private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
+private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
 
   private val timeout = AkkaUtils.askTimeout(conf)
 
   // Set to the MapOutputTrackerActor living on the driver
   var trackerActor: ActorRef = _
 
-  protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+  /** This HashMap needs to have different storage behavior for driver and worker */
+  protected val mapStatuses: Map[Int, Array[MapStatus]]
 
   // Incremented every time a fetch fails so that client nodes know to clear
   // their cache of map output locations if this happens.
   protected var epoch: Long = 0
   protected val epochLock = new java.lang.Object
 
-  private val metadataCleaner =
-    new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf)
-
   // Send a message to the trackerActor and get its result within a default timeout, or
   // throw a SparkException if this fails.
   private def askTracker(message: Any): Any = {
@@ -138,8 +136,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
         fetchedStatuses.synchronized {
           return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
         }
-      }
-      else {
+      } else {
         throw new FetchFailedException(null, shuffleId, -1, reduceId,
           new Exception("Missing all output locations for shuffle " + shuffleId))
       }
@@ -151,13 +148,12 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
   }
 
   protected def cleanup(cleanupTime: Long) {
-    mapStatuses.clearOldValues(cleanupTime)
+    mapStatuses.asInstanceOf[TimeStampedHashMap[_, _]].clearOldValues(cleanupTime)
   }
 
  def stop() {
    communicate(StopMapOutputTracker)
    mapStatuses.clear()
-    metadataCleaner.cancel()
    trackerActor = null
  }
 
@@ -182,15 +178,42 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
   }
 }
 
+private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
+
+  /**
+   * Bounded HashMap for storing serialized statuses in the worker. This allows
+   * the HashMap to stay bounded in memory usage. Things dropped from this HashMap will be
+   * automatically repopulated by fetching them again from the driver.
+   */
+  protected val MAX_MAP_STATUSES = 100
+  protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]](MAX_MAP_STATUSES, true)
+}
+
+
 private[spark] class MapOutputTrackerMaster(conf: SparkConf)
   extends MapOutputTracker(conf) {
 
   // Cache a serialized version of the output statuses for each shuffle to send them out faster
   private var cacheEpoch = epoch
-  private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
+
+  /**
+   * Timestamp-based HashMap for storing mapStatuses in the master, so that statuses are dropped
+   * only by explicit deregistering or by TTL-based cleaning (if set). Other than these two
+   * scenarios, nothing should be dropped from this HashMap.
+   */
+  protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]()
+
+  /**
+   * Bounded HashMap for storing serialized statuses in the master. This allows
+   * the HashMap to stay bounded in memory usage. Things dropped from this HashMap will be
+   * automatically repopulated by serializing the lost statuses again.
+   */
+  protected val MAX_SERIALIZED_STATUSES = 100
+  private val cachedSerializedStatuses =
+    new BoundedHashMap[Int, Array[Byte]](MAX_SERIALIZED_STATUSES, true)
 
   def registerShuffle(shuffleId: Int, numMaps: Int) {
-    if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+    if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
       throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
     }
   }
@@ -224,6 +247,10 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
     }
   }
 
+  def unregisterShuffle(shuffleId: Int) {
+    mapStatuses.remove(shuffleId)
+  }
+
   def incrementEpoch() {
     epochLock.synchronized {
       epoch += 1
@@ -260,9 +287,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
     bytes
   }
 
-  protected override def cleanup(cleanupTime: Long) {
-    super.cleanup(cleanupTime)
-    cachedSerializedStatuses.clearOldValues(cleanupTime)
+  def contains(shuffleId: Int): Boolean = {
+    mapStatuses.contains(shuffleId)
   }
 
   override def stop() {
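
Note: BoundedHashMap itself lives in org.apache.spark.util and is not part of this excerpt. The sketch below (not from the commit) only illustrates the semantics relied on above: a fixed-capacity map that evicts old entries, which is safe here because a dropped entry can be refetched from the driver (worker side) or re-serialized (master side). Reading the second constructor argument as an LRU/access-order toggle is an assumption.

// Sketch only: an LRU-bounded map with roughly the put/get/contains surface used above.
import java.util.{LinkedHashMap => JLinkedHashMap, Map => JMap}

class BoundedMapSketch[K, V](maxSize: Int, useLruOrder: Boolean) {
  private val underlying = new JLinkedHashMap[K, V](maxSize, 0.75f, useLruOrder) {
    // Evict the eldest (least recently used, if useLruOrder) entry once capacity is exceeded.
    override def removeEldestEntry(eldest: JMap.Entry[K, V]): Boolean = size() > maxSize
  }
  def put(key: K, value: V): Option[V] = synchronized { Option(underlying.put(key, value)) }
  def get(key: K): Option[V] = synchronized { Option(underlying.get(key)) }
  def contains(key: K): Boolean = synchronized { underlying.containsKey(key) }
  def remove(key: K): Option[V] = synchronized { Option(underlying.remove(key)) }
}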

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 9 additions & 3 deletions
@@ -48,8 +48,10 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me
 import org.apache.spark.scheduler.local.LocalBackend
 import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
 import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{Utils, TimeStampedHashMap, MetadataCleaner, MetadataCleanerType,
-  ClosureCleaner}
+import org.apache.spark.util._
+import scala.Some
+import org.apache.spark.storage.RDDInfo
+import org.apache.spark.storage.StorageStatus
 
 /**
  * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -150,7 +152,7 @@ class SparkContext(
   private[spark] val addedJars = HashMap[String, Long]()
 
   // Keeps track of all persisted RDDs
-  private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
+  private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]]
   private[spark] val metadataCleaner =
     new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf)
 
@@ -202,6 +204,9 @@ class SparkContext(
   @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler)
   dagScheduler.start()
 
+  private[spark] val cleaner = new ContextCleaner(env)
+  cleaner.start()
+
   ui.start()
 
   /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
@@ -784,6 +789,7 @@ class SparkContext(
     dagScheduler = null
     if (dagSchedulerCopy != null) {
       metadataCleaner.cancel()
+      cleaner.stop()
       dagSchedulerCopy.stop()
       taskScheduler = null
       // TODO: Cache.stop()?
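
Note: TimeStampedWeakValueHashMap is in org.apache.spark.util and not shown in this excerpt. The sketch below (not from the commit) illustrates the property the persistentRdds change depends on: values are held only through weak references, so the registry no longer keeps an otherwise-unreferenced cached RDD alive, allowing it to be garbage collected and finalized, which is what triggers the ContextCleaner.

// Sketch only: a map whose values can be garbage collected once no strong references remain.
import java.lang.ref.WeakReference
import scala.collection.mutable

class WeakValueMapSketch[K, V <: AnyRef] {
  private val underlying = new mutable.HashMap[K, WeakReference[V]]

  def put(key: K, value: V): Unit = synchronized { underlying(key) = new WeakReference(value) }

  // None if the key was never added, was removed, or its value has been collected.
  def get(key: K): Option[V] = synchronized { underlying.get(key).flatMap(ref => Option(ref.get)) }

  def remove(key: K): Unit = synchronized { underlying -= key }

  // Drop entries whose values have already been collected.
  def prune(): Unit = synchronized { underlying.retain((_, ref) => ref.get != null) }
}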

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ object SparkEnv extends Logging {
     val mapOutputTracker = if (isDriver) {
       new MapOutputTrackerMaster(conf)
     } else {
-      new MapOutputTracker(conf)
+      new MapOutputTrackerWorker(conf)
     }
     mapOutputTracker.trackerActor = registerOrLookup(
       "MapOutputTracker",

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 10 additions & 0 deletions
@@ -1012,6 +1012,13 @@ abstract class RDD[T: ClassTag](
     checkpointData.flatMap(_.getCheckpointFile)
   }
 
+  def cleanup() {
+    sc.cleaner.cleanRDD(this)
+    dependencies.filter(_.isInstanceOf[ShuffleDependency[_, _]])
+      .map(_.asInstanceOf[ShuffleDependency[_, _]].shuffleId)
+      .foreach(sc.cleaner.cleanShuffle)
+  }
+
   // =======================================================================
   // Other internal methods and fields
   // =======================================================================
@@ -1091,4 +1098,7 @@ abstract class RDD[T: ClassTag](
     new JavaRDD(this)(elementClassTag)
   }
 
+  override def finalize() {
+    cleanup()
+  }
 }
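
Note: besides the GC-driven finalize() path above, cleanup() can also be called directly. The sketch below (not from the commit) shows explicitly releasing a cached RDD and the shuffle behind it once only the collected result is needed; variable names are illustrative, and the `import org.apache.spark.SparkContext._` line is the pair-RDD import required in this version of Spark.

// Sketch only: explicit cleanup of a cached RDD and its shuffle output.
import org.apache.spark.SparkContext._   // pair-RDD functions such as reduceByKey

val pairs = sc.parallelize(1 to 100000).map(i => (i % 100, 1L))
val counts = pairs.reduceByKey(_ + _).cache()
val result = counts.collect()   // only the local array is needed from here on

// Removes the cached blocks, drops the entry from persistentRdds, unregisters the shuffle
// with MapOutputTrackerMaster, and asks the BlockManager master to remove the shuffle blocks.
counts.cleanup()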

core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala

Lines changed: 4 additions & 8 deletions
@@ -21,20 +21,16 @@ import java.io._
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
 import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.RDDCheckpointData
-import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.rdd.{RDD, RDDCheckpointData}
+import org.apache.spark.util.BoundedHashMap
 
 private[spark] object ResultTask {
 
   // A simple map between the stage id to the serialized byte array of a task.
   // Served as a cache for task serialization because serialization can be
   // expensive on the master node if it needs to launch thousands of tasks.
-  val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
-
-  // TODO: This object shouldn't have global variables
-  val metadataCleaner = new MetadataCleaner(
-    MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues, new SparkConf)
+  val MAX_CACHE_SIZE = 100
+  val serializedInfoCache = new BoundedHashMap[Int, Array[Byte]](MAX_CACHE_SIZE, true)
 
   def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
     synchronized {
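
Note: the bounded cache above is safe precisely because a missing entry can always be rebuilt. The sketch below (not from the commit) shows that miss-and-repopulate pattern with a hypothetical helper; it assumes BoundedHashMap exposes an Option-returning get, which this excerpt does not show.

// Sketch only: rebuild and re-cache serialized task info when an entry has been evicted.
def serializedInfoFor(stageId: Int, serialize: () => Array[Byte]): Array[Byte] = {
  ResultTask.synchronized {
    ResultTask.serializedInfoCache.get(stageId) match {
      case Some(bytes) => bytes          // cache hit: reuse the serialized bytes
      case None =>                       // never cached, or evicted by the bound: redo the work
        val bytes = serialize()
        ResultTask.serializedInfoCache.put(stageId, bytes)
        bytes
    }
  }
}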
