
Commit e61daa0

Modifications based on the comments on PR 126.
1 parent ae9da88 commit e61daa0

9 files changed, +71 -60 lines changed


core/src/main/scala/org/apache/spark/ContextCleaner.scala

Lines changed: 32 additions & 23 deletions
@@ -21,8 +21,6 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
 
 import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
 
-import org.apache.spark.rdd.RDD
-
 /** Listener class used for testing when any item has been cleaned by the Cleaner class */
 private[spark] trait CleanerListener {
   def rddCleaned(rddId: Int)
@@ -32,12 +30,12 @@ private[spark] trait CleanerListener {
 /**
  * Cleans RDDs and shuffle data.
  */
-private[spark] class ContextCleaner(env: SparkEnv) extends Logging {
+private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
 
   /** Classes to represent cleaning tasks */
   private sealed trait CleaningTask
-  private case class CleanRDD(sc: SparkContext, id: Int) extends CleaningTask
-  private case class CleanShuffle(id: Int) extends CleaningTask
+  private case class CleanRDD(rddId: Int) extends CleaningTask
+  private case class CleanShuffle(shuffleId: Int) extends CleaningTask
   // TODO: add CleanBroadcast
 
   private val queue = new LinkedBlockingQueue[CleaningTask]
@@ -47,7 +45,7 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging {
 
   private val cleaningThread = new Thread() { override def run() { keepCleaning() }}
 
-  private var stopped = false
+  @volatile private var stopped = false
 
   /** Start the cleaner */
   def start() {
@@ -57,26 +55,37 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging {
 
   /** Stop the cleaner */
   def stop() {
-    synchronized { stopped = true }
+    stopped = true
     cleaningThread.interrupt()
   }
 
-  /** Clean (unpersist) RDD data. */
-  def cleanRDD(rdd: RDD[_]) {
-    enqueue(CleanRDD(rdd.sparkContext, rdd.id))
-    logDebug("Enqueued RDD " + rdd + " for cleaning up")
+  /**
+   * Clean (unpersist) RDD data. Do not perform any time or resource intensive
+   * computation in this function as this is called from a finalize() function.
+   */
+  def cleanRDD(rddId: Int) {
+    enqueue(CleanRDD(rddId))
+    logDebug("Enqueued RDD " + rddId + " for cleaning up")
   }
 
-  /** Clean shuffle data. */
+  /**
+   * Clean shuffle data. Do not perform any time or resource intensive
+   * computation in this function as this is called from a finalize() function.
+   */
   def cleanShuffle(shuffleId: Int) {
     enqueue(CleanShuffle(shuffleId))
     logDebug("Enqueued shuffle " + shuffleId + " for cleaning up")
   }
 
+  /** Attach a listener object to get information of when objects are cleaned. */
   def attachListener(listener: CleanerListener) {
     listeners += listener
   }
-  /** Enqueue a cleaning task */
+
+  /**
+   * Enqueue a cleaning task. Do not perform any time or resource intensive
+   * computation in this function as this is called from a finalize() function.
+   */
   private def enqueue(task: CleaningTask) {
     queue.put(task)
   }
@@ -86,24 +95,24 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging {
     try {
       while (!isStopped) {
         val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS))
-        if (taskOpt.isDefined) {
+        taskOpt.foreach(task => {
          logDebug("Got cleaning task " + taskOpt.get)
-          taskOpt.get match {
-            case CleanRDD(sc, rddId) => doCleanRDD(sc, rddId)
+          task match {
+            case CleanRDD(rddId) => doCleanRDD(sc, rddId)
             case CleanShuffle(shuffleId) => doCleanShuffle(shuffleId)
           }
-        }
+        })
       }
     } catch {
-      case ie: java.lang.InterruptedException =>
+      case ie: InterruptedException =>
        if (!isStopped) logWarning("Cleaning thread interrupted")
     }
   }
 
   /** Perform RDD cleaning */
   private def doCleanRDD(sc: SparkContext, rddId: Int) {
     logDebug("Cleaning rdd " + rddId)
-    sc.env.blockManager.master.removeRdd(rddId, false)
+    blockManagerMaster.removeRdd(rddId, false)
     sc.persistentRdds.remove(rddId)
     listeners.foreach(_.rddCleaned(rddId))
     logInfo("Cleaned rdd " + rddId)
@@ -113,14 +122,14 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging {
   private def doCleanShuffle(shuffleId: Int) {
     logDebug("Cleaning shuffle " + shuffleId)
     mapOutputTrackerMaster.unregisterShuffle(shuffleId)
-    blockManager.master.removeShuffle(shuffleId)
+    blockManagerMaster.removeShuffle(shuffleId)
     listeners.foreach(_.shuffleCleaned(shuffleId))
     logInfo("Cleaned shuffle " + shuffleId)
   }
 
-  private def mapOutputTrackerMaster = env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+  private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
 
-  private def blockManager = env.blockManager
+  private def blockManagerMaster = sc.env.blockManager.master
 
-  private def isStopped = synchronized { stopped }
+  private def isStopped = stopped
 }
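The cleaner follows a simple producer-consumer pattern: callers (including an RDD's finalize()) only enqueue a lightweight task, and a single background thread drains the queue and does the actual cleanup. A minimal standalone sketch of that pattern is below; the names CleanerSketch and CleanThing are illustrative, not part of Spark.

import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

// Producers only enqueue; one daemon thread does the real work.
object CleanerSketch {
  private sealed trait Task
  private case class CleanThing(id: Int) extends Task

  private val queue = new LinkedBlockingQueue[Task]
  @volatile private var stopped = false

  private val worker = new Thread {
    override def run() {
      try {
        while (!stopped) {
          // Poll with a timeout so the loop notices `stopped` promptly.
          Option(queue.poll(100, TimeUnit.MILLISECONDS)).foreach {
            case CleanThing(id) => println("cleaning " + id) // stand-in for real cleanup
          }
        }
      } catch {
        case _: InterruptedException => // stop() interrupts a blocked poll
      }
    }
  }
  worker.setDaemon(true)

  def start() { worker.start() }

  // Cheap and non-blocking, hence safe to call from a finalize() method.
  def clean(id: Int) { queue.put(CleanThing(id)) }

  def stop() { stopped = true; worker.interrupt() }
}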

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 24 additions & 28 deletions
@@ -17,22 +17,18 @@
 
 package org.apache.spark
 
-import scala.Some
-import scala.collection.mutable.{HashSet, Map}
-import scala.concurrent.Await
-
 import java.io._
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
-import scala.collection.mutable.HashSet
+import scala.Some
+import scala.collection.mutable.{HashSet, Map}
 import scala.concurrent.Await
 
 import akka.actor._
 import akka.pattern.ask
-
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.{AkkaUtils, TimeStampedHashMap, BoundedHashMap}
+import org.apache.spark.util._
 
 private[spark] sealed trait MapOutputTrackerMessage
 private[spark] case class GetMapOutputStatuses(shuffleId: Int)
@@ -55,7 +51,7 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
 }
 
 /**
- * Class that keeps track of the location of the location of the mapt output of
+ * Class that keeps track of the location of the location of the map output of
  * a stage. This is abstract because different versions of MapOutputTracker
 * (driver and worker) use different HashMap to store its metadata.
 */
@@ -155,10 +151,6 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
     }
   }
 
-  protected def cleanup(cleanupTime: Long) {
-    mapStatuses.asInstanceOf[TimeStampedHashMap[_, _]].clearOldValues(cleanupTime)
-  }
-
   def stop() {
     communicate(StopMapOutputTracker)
     mapStatuses.clear()
@@ -195,10 +187,13 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
  /**
   * Bounded HashMap for storing serialized statuses in the worker. This allows
   * the HashMap stay bounded in memory-usage. Things dropped from this HashMap will be
-  * automatically repopulated by fetching them again from the driver.
+  * automatically repopulated by fetching them again from the driver. Its okay to
+  * keep the cache size small as it unlikely that there will be a very large number of
+  * stages active simultaneously in the worker.
   */
-  protected val MAX_MAP_STATUSES = 100
-  protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]](MAX_MAP_STATUSES, true)
+  protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]](
+    conf.getInt("spark.mapOutputTracker.cacheSize", 100), true
+  )
 }
 
 /**
@@ -212,20 +207,18 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
   private var cacheEpoch = epoch
 
   /**
-   * Timestamp based HashMap for storing mapStatuses in the master, so that statuses are dropped
-   * only by explicit deregistering or by ttl-based cleaning (if set). Other than these two
+   * Timestamp based HashMap for storing mapStatuses and cached serialized statuses
+   * in the master, so that statuses are dropped only by explicit deregistering or
+   * by TTL-based cleaning (if set). Other than these two
    * scenarios, nothing should be dropped from this HashMap.
   */
+
  protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]()
+  private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]()
 
-  /**
-   * Bounded HashMap for storing serialized statuses in the master. This allows
-   * the HashMap stay bounded in memory-usage. Things dropped from this HashMap will be
-   * automatically repopulated by serializing the lost statuses again .
-   */
-  protected val MAX_SERIALIZED_STATUSES = 100
-  private val cachedSerializedStatuses =
-    new BoundedHashMap[Int, Array[Byte]](MAX_SERIALIZED_STATUSES, true)
+  // For cleaning up TimeStampedHashMaps
+  private val metadataCleaner =
+    new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf)
 
   def registerShuffle(shuffleId: Int, numMaps: Int) {
     if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
@@ -264,6 +257,7 @@
 
   def unregisterShuffle(shuffleId: Int) {
     mapStatuses.remove(shuffleId)
+    cachedSerializedStatuses.remove(shuffleId)
   }
 
   def incrementEpoch() {
@@ -303,20 +297,22 @@
   }
 
   def contains(shuffleId: Int): Boolean = {
-    mapStatuses.contains(shuffleId)
+    cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
   }
 
   override def stop() {
     super.stop()
+    metadataCleaner.cancel()
     cachedSerializedStatuses.clear()
   }
 
   override def updateEpoch(newEpoch: Long) {
     // This might be called on the MapOutputTrackerMaster if we're running in local mode.
   }
 
-  def has(shuffleId: Int): Boolean = {
-    cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId)
+  protected def cleanup(cleanupTime: Long) {
+    mapStatuses.clearOldValues(cleanupTime)
+    cachedSerializedStatuses.clearOldValues(cleanupTime)
   }
 }
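The worker-side tracker keeps its statuses in a size-bounded map, so memory stays capped and anything evicted is simply re-fetched from the driver. As a rough illustration of how such a bound can be enforced (a sketch only, not Spark's actual BoundedHashMap), an access-ordered java.util.LinkedHashMap with an eviction rule does the job:

import java.util.{LinkedHashMap => JLinkedHashMap, Map => JMap}

// Size-bounded, least-recently-used cache; entries beyond maxEntries are evicted.
class BoundedCache[K, V](maxEntries: Int) {
  private val underlying = new JLinkedHashMap[K, V](16, 0.75f, true) { // accessOrder = true
    override def removeEldestEntry(eldest: JMap.Entry[K, V]): Boolean = size() > maxEntries
  }

  def put(k: K, v: V) { underlying.synchronized { underlying.put(k, v) } }
  def get(k: K): Option[V] = underlying.synchronized { Option(underlying.get(k)) }
}

// e.g. new BoundedCache[Int, Array[Byte]](100) keeps at most the 100 most recently used
// entries, mirroring the behaviour relied on for worker-side map output statuses.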

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 1 addition & 1 deletion
@@ -206,7 +206,7 @@ class SparkContext(
   @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler)
   dagScheduler.start()
 
-  private[spark] val cleaner = new ContextCleaner(env)
+  private[spark] val cleaner = new ContextCleaner(this)
   cleaner.start()
 
   ui.start()

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 1 addition & 1 deletion
@@ -1027,7 +1027,7 @@ abstract class RDD[T: ClassTag](
 
   def cleanup() {
     logInfo("Cleanup called on RDD " + id)
-    sc.cleaner.cleanRDD(this)
+    sc.cleaner.cleanRDD(id)
     dependencies.filter(_.isInstanceOf[ShuffleDependency[_, _]])
       .map(_.asInstanceOf[ShuffleDependency[_, _]].shuffleId)
       .foreach(sc.cleaner.cleanShuffle)

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 1 addition & 1 deletion
@@ -266,7 +266,7 @@ class DAGScheduler(
     : Stage =
   {
     val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite)
-    if (mapOutputTracker.has(shuffleDep.shuffleId)) {
+    if (mapOutputTracker.contains(shuffleDep.shuffleId)) {
       val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
       val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
       for (i <- 0 until locs.size) {

core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala

Lines changed: 8 additions & 2 deletions
@@ -169,8 +169,14 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
       throw new IllegalStateException("Failed to find shuffle block: " + id)
   }
 
-  /** Remove all the blocks / files related to a particular shuffle */
+  /** Remove all the blocks / files and metadata related to a particular shuffle */
   def removeShuffle(shuffleId: ShuffleId) {
+    removeShuffleBlocks(shuffleId)
+    shuffleStates.remove(shuffleId)
+  }
+
+  /** Remove all the blocks / files related to a particular shuffle */
+  private def removeShuffleBlocks(shuffleId: ShuffleId) {
     shuffleStates.get(shuffleId) match {
       case Some(state) =>
         if (consolidateShuffleFiles) {
@@ -194,7 +200,7 @@
   }
 
   private def cleanup(cleanupTime: Long) {
-    shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffle(shuffleId))
+    shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId))
   }
 }

core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala

Lines changed: 2 additions & 2 deletions
@@ -62,8 +62,8 @@ private[spark] class MetadataCleaner(
 
 private[spark] object MetadataCleanerType extends Enumeration {
 
-  val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, BLOCK_MANAGER,
-    SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS, CLEANER = Value
+  val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, BLOCK_MANAGER,
+    SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
 
   type MetadataCleanerType = Value
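MetadataCleaner drives TTL-based cleanup: maps like the two TimeStampedHashMaps now registered in MapOutputTrackerMaster record an insertion timestamp per entry, and a periodic task drops entries older than the configured TTL. Below is a hypothetical standalone sketch of that idea using only the standard library; TimestampedMap and the scheduling wiring are made-up names, not Spark's TimeStampedHashMap or MetadataCleaner API.

import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}
import scala.collection.JavaConverters._

// Map that remembers when each entry was inserted.
class TimestampedMap[K, V] {
  private val underlying = new ConcurrentHashMap[K, (V, Long)]

  def put(k: K, v: V) { underlying.put(k, (v, System.currentTimeMillis)) }
  def get(k: K): Option[V] = Option(underlying.get(k)).map(_._1)
  def remove(k: K) { underlying.remove(k) }

  // Drop every entry inserted before the threshold time.
  def clearOldValues(threshTime: Long) {
    underlying.entrySet.asScala.foreach { e =>
      if (e.getValue._2 < threshTime) underlying.remove(e.getKey)
    }
  }
}

object TtlCleanerDemo extends App {
  val statuses = new TimestampedMap[Int, String]
  val ttlMs = 60 * 1000L
  val scheduler = Executors.newSingleThreadScheduledExecutor()
  // Periodically clear entries older than the TTL, as a MetadataCleaner would.
  scheduler.scheduleAtFixedRate(new Runnable {
    def run() { statuses.clearOldValues(System.currentTimeMillis - ttlMs) }
  }, ttlMs, ttlMs, TimeUnit.MILLISECONDS)
}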

core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     val rdd = newRDD.persist()
     rdd.count()
     val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
-    cleaner.cleanRDD(rdd)
+    cleaner.cleanRDD(rdd.id)
     tester.assertCleanup
   }

core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -206,4 +206,4 @@ class TestMap[A, B] extends WrappedJavaHashMap[A, B, A, B] {
   protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = {
     new TestMap[K1, V1]
   }
-}
+}
