Skip to content

Commit 892b952

Browse files
committed
Removed use of BoundedHashMap, and made BlockManagerSlaveActor cleanup shuffle metadata in MapOutputTrackerWorker.
1 parent a7260d3 commit 892b952

17 files changed

+196
-147
lines changed

core/src/main/scala/org/apache/spark/ContextCleaner.scala

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark
2020
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
2121

2222
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
23+
import org.apache.spark.storage.StorageLevel
2324

2425
/** Listener class used for testing when any item has been cleaned by the Cleaner class */
2526
private[spark] trait CleanerListener {
@@ -61,19 +62,19 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
6162
}
6263

6364
/**
64-
* Clean RDD data. Do not perform any time or resource intensive
65+
* Schedule cleanup of RDD data. Do not perform any time or resource intensive
6566
* computation in this function as this is called from a finalize() function.
6667
*/
67-
def cleanRDD(rddId: Int) {
68+
def scheduleRDDCleanup(rddId: Int) {
6869
enqueue(CleanRDD(rddId))
6970
logDebug("Enqueued RDD " + rddId + " for cleaning up")
7071
}
7172

7273
/**
73-
* Clean shuffle data. Do not perform any time or resource intensive
74+
* Schedule cleanup of shuffle data. Do not perform any time or resource intensive
7475
* computation in this function as this is called from a finalize() function.
7576
*/
76-
def cleanShuffle(shuffleId: Int) {
77+
def scheduleShuffleCleanup(shuffleId: Int) {
7778
enqueue(CleanShuffle(shuffleId))
7879
logDebug("Enqueued shuffle " + shuffleId + " for cleaning up")
7980
}
@@ -83,6 +84,13 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
8384
listeners += listener
8485
}
8586

87+
/** Unpersists RDD and remove all blocks for it from memory and disk. */
88+
def unpersistRDD(rddId: Int, blocking: Boolean) {
89+
logDebug("Unpersisted RDD " + rddId)
90+
sc.env.blockManager.master.removeRdd(rddId, blocking)
91+
sc.persistentRdds.remove(rddId)
92+
}
93+
8694
/**
8795
* Enqueue a cleaning task. Do not perform any time or resource intensive
8896
* computation in this function as this is called from a finalize() function.
@@ -115,8 +123,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
115123
private def doCleanRDD(rddId: Int) {
116124
try {
117125
logDebug("Cleaning RDD " + rddId)
118-
blockManagerMaster.removeRdd(rddId, false)
119-
sc.persistentRdds.remove(rddId)
126+
unpersistRDD(rddId, false)
120127
listeners.foreach(_.rddCleaned(rddId))
121128
logInfo("Cleaned RDD " + rddId)
122129
} catch {

core/src/main/scala/org/apache/spark/Dependency.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,15 @@ class ShuffleDependency[K, V](
5656
override def finalize() {
5757
try {
5858
if (rdd != null) {
59-
rdd.sparkContext.cleaner.cleanShuffle(shuffleId)
59+
rdd.sparkContext.cleaner.scheduleShuffleCleanup(shuffleId)
6060
}
6161
} catch {
6262
case t: Throwable =>
6363
// Paranoia - If logError throws error as well, report to stderr.
6464
try {
6565
logError("Error in finalize", t)
6666
} catch {
67-
case _ =>
67+
case _ : Throwable =>
6868
System.err.println("Error in finalize (and could not write to logError): " + t)
6969
}
7070
} finally {

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 55 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark
2020
import java.io._
2121
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
2222

23-
import scala.collection.mutable.{HashSet, Map}
23+
import scala.collection.mutable.{HashSet, HashMap, Map}
2424
import scala.concurrent.Await
2525

2626
import akka.actor._
@@ -34,6 +34,7 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
3434
extends MapOutputTrackerMessage
3535
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
3636

37+
/** Actor class for MapOutputTrackerMaster */
3738
private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
3839
extends Actor with Logging {
3940
def receive = {
@@ -50,28 +51,35 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
5051
}
5152

5253
/**
53-
* Class that keeps track of the location of the location of the map output of
54+
* Class that keeps track of the location of the map output of
5455
* a stage. This is abstract because different versions of MapOutputTracker
5556
* (driver and worker) use different HashMap to store its metadata.
5657
*/
5758
private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
5859

5960
private val timeout = AkkaUtils.askTimeout(conf)
6061

61-
// Set to the MapOutputTrackerActor living on the driver
62+
/** Set to the MapOutputTrackerActor living on the driver */
6263
var trackerActor: ActorRef = _
6364

6465
/** This HashMap needs to have different storage behavior for driver and worker */
6566
protected val mapStatuses: Map[Int, Array[MapStatus]]
6667

67-
// Incremented every time a fetch fails so that client nodes know to clear
68-
// their cache of map output locations if this happens.
68+
/**
69+
* Incremented every time a fetch fails so that client nodes know to clear
70+
* their cache of map output locations if this happens.
71+
*/
6972
protected var epoch: Long = 0
7073
protected val epochLock = new java.lang.Object
7174

72-
// Send a message to the trackerActor and get its result within a default timeout, or
73-
// throw a SparkException if this fails.
74-
private def askTracker(message: Any): Any = {
75+
/** Remembers which map output locations are currently being fetched on a worker */
76+
private val fetching = new HashSet[Int]
77+
78+
/**
79+
* Send a message to the trackerActor and get its result within a default timeout, or
80+
* throw a SparkException if this fails.
81+
*/
82+
protected def askTracker(message: Any): Any = {
7583
try {
7684
val future = trackerActor.ask(message)(timeout)
7785
Await.result(future, timeout)
@@ -81,17 +89,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
8189
}
8290
}
8391

84-
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
85-
private def communicate(message: Any) {
92+
/** Send a one-way message to the trackerActor, to which we expect it to reply with true. */
93+
protected def sendTracker(message: Any) {
8694
if (askTracker(message) != true) {
8795
throw new SparkException("Error reply received from MapOutputTracker")
8896
}
8997
}
9098

91-
// Remembers which map output locations are currently being fetched on a worker
92-
private val fetching = new HashSet[Int]
93-
94-
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
99+
/**
100+
* Called from executors to get the server URIs and
101+
* output sizes of the map outputs of a given shuffle
102+
*/
95103
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
96104
val statuses = mapStatuses.get(shuffleId).orNull
97105
if (statuses == null) {
@@ -150,22 +158,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
150158
}
151159
}
152160

153-
def stop() {
154-
communicate(StopMapOutputTracker)
155-
mapStatuses.clear()
156-
trackerActor = null
157-
}
158-
159-
// Called to get current epoch number
161+
/** Called to get current epoch number */
160162
def getEpoch: Long = {
161163
epochLock.synchronized {
162164
return epoch
163165
}
164166
}
165167

166-
// Called on workers to update the epoch number, potentially clearing old outputs
167-
// because of a fetch failure. (Each worker task calls this with the latest epoch
168-
// number on the master at the time it was created.)
168+
/**
169+
* Called from executors to update the epoch number, potentially clearing old outputs
170+
* because of a fetch failure. Each worker task calls this with the latest epoch
171+
* number on the master at the time it was created.
172+
*/
169173
def updateEpoch(newEpoch: Long) {
170174
epochLock.synchronized {
171175
if (newEpoch > epoch) {
@@ -175,24 +179,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
175179
}
176180
}
177181
}
178-
}
179182

180-
/**
181-
* MapOutputTracker for the workers. This uses BoundedHashMap to keep track of
182-
* a limited number of most recently used map output information.
183-
*/
184-
private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
183+
/** Unregister shuffle data */
184+
def unregisterShuffle(shuffleId: Int) {
185+
mapStatuses.remove(shuffleId)
186+
}
185187

186-
/**
187-
* Bounded HashMap for storing serialized statuses in the worker. This allows
188-
* the HashMap stay bounded in memory-usage. Things dropped from this HashMap will be
189-
* automatically repopulated by fetching them again from the driver. Its okay to
190-
* keep the cache size small as it unlikely that there will be a very large number of
191-
* stages active simultaneously in the worker.
192-
*/
193-
protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]](
194-
conf.getInt("spark.mapOutputTracker.cacheSize", 100), true
195-
)
188+
def stop() {
189+
sendTracker(StopMapOutputTracker)
190+
mapStatuses.clear()
191+
trackerActor = null
192+
}
196193
}
197194

198195
/**
@@ -202,7 +199,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
202199
private[spark] class MapOutputTrackerMaster(conf: SparkConf)
203200
extends MapOutputTracker(conf) {
204201

205-
// Cache a serialized version of the output statuses for each shuffle to send them out faster
202+
/** Cache a serialized version of the output statuses for each shuffle to send them out faster */
206203
private var cacheEpoch = epoch
207204

208205
/**
@@ -211,7 +208,6 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
211208
* by TTL-based cleaning (if set). Other than these two
212209
* scenarios, nothing should be dropped from this HashMap.
213210
*/
214-
215211
protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]()
216212
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]()
217213

@@ -232,13 +228,15 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
232228
}
233229
}
234230

231+
/** Register multiple map output information for the given shuffle */
235232
def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
236233
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
237234
if (changeEpoch) {
238235
incrementEpoch()
239236
}
240237
}
241238

239+
/** Unregister map output information of the given shuffle, mapper and block manager */
242240
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
243241
val arrayOpt = mapStatuses.get(shuffleId)
244242
if (arrayOpt.isDefined && arrayOpt.get != null) {
@@ -254,11 +252,17 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
254252
}
255253
}
256254

257-
def unregisterShuffle(shuffleId: Int) {
255+
/** Unregister shuffle data */
256+
override def unregisterShuffle(shuffleId: Int) {
258257
mapStatuses.remove(shuffleId)
259258
cachedSerializedStatuses.remove(shuffleId)
260259
}
261260

261+
/** Check if the given shuffle is being tracked */
262+
def containsShuffle(shuffleId: Int): Boolean = {
263+
cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
264+
}
265+
262266
def incrementEpoch() {
263267
epochLock.synchronized {
264268
epoch += 1
@@ -295,26 +299,26 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
295299
bytes
296300
}
297301

298-
def contains(shuffleId: Int): Boolean = {
299-
cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
300-
}
301-
302302
override def stop() {
303303
super.stop()
304304
metadataCleaner.cancel()
305305
cachedSerializedStatuses.clear()
306306
}
307307

308-
override def updateEpoch(newEpoch: Long) {
309-
// This might be called on the MapOutputTrackerMaster if we're running in local mode.
310-
}
311-
312308
protected def cleanup(cleanupTime: Long) {
313309
mapStatuses.clearOldValues(cleanupTime)
314310
cachedSerializedStatuses.clearOldValues(cleanupTime)
315311
}
316312
}
317313

314+
/**
315+
* MapOutputTracker for the workers, which fetches map output information from the driver's
316+
* MapOutputTrackerMaster.
317+
*/
318+
private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
319+
protected val mapStatuses = new HashMap[Int, Array[MapStatus]]
320+
}
321+
318322
private[spark] object MapOutputTracker {
319323
private val LOG_BASE = 1.1
320324

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -165,18 +165,6 @@ object SparkEnv extends Logging {
165165
}
166166
}
167167

168-
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
169-
"BlockManagerMaster",
170-
new BlockManagerMasterActor(isLocal, conf)), conf)
171-
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
172-
serializer, conf, securityManager)
173-
174-
val connectionManager = blockManager.connectionManager
175-
176-
val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
177-
178-
val cacheManager = new CacheManager(blockManager)
179-
180168
// Have to assign trackerActor after initialization as MapOutputTrackerActor
181169
// requires the MapOutputTracker itself
182170
val mapOutputTracker = if (isDriver) {
@@ -188,6 +176,19 @@ object SparkEnv extends Logging {
188176
"MapOutputTracker",
189177
new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
190178

179+
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
180+
"BlockManagerMaster",
181+
new BlockManagerMasterActor(isLocal, conf)), conf)
182+
183+
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
184+
serializer, conf, securityManager, mapOutputTracker)
185+
186+
val connectionManager = blockManager.connectionManager
187+
188+
val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
189+
190+
val cacheManager = new CacheManager(blockManager)
191+
191192
val shuffleFetcher = instantiateClass[ShuffleFetcher](
192193
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
193194

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,7 @@ abstract class RDD[T: ClassTag](
165165
*/
166166
def unpersist(blocking: Boolean = true): RDD[T] = {
167167
logInfo("Removing RDD " + id + " from persistence list")
168-
sc.env.blockManager.master.removeRdd(id, blocking)
169-
sc.persistentRdds.remove(id)
168+
sc.cleaner.unpersistRDD(id, blocking)
170169
storageLevel = StorageLevel.NONE
171170
this
172171
}
@@ -1025,14 +1024,6 @@ abstract class RDD[T: ClassTag](
10251024
checkpointData.flatMap(_.getCheckpointFile)
10261025
}
10271026

1028-
def cleanup() {
1029-
logInfo("Cleanup called on RDD " + id)
1030-
sc.cleaner.cleanRDD(id)
1031-
dependencies.filter(_.isInstanceOf[ShuffleDependency[_, _]])
1032-
.map(_.asInstanceOf[ShuffleDependency[_, _]].shuffleId)
1033-
.foreach(sc.cleaner.cleanShuffle)
1034-
}
1035-
10361027
// =======================================================================
10371028
// Other internal methods and fields
10381029
// =======================================================================
@@ -1114,14 +1105,14 @@ abstract class RDD[T: ClassTag](
11141105

11151106
override def finalize() {
11161107
try {
1117-
cleanup()
1108+
sc.cleaner.scheduleRDDCleanup(id)
11181109
} catch {
11191110
case t: Throwable =>
11201111
// Paranoia - If logError throws error as well, report to stderr.
11211112
try {
11221113
logError("Error in finalize", t)
11231114
} catch {
1124-
case _ =>
1115+
case _ : Throwable =>
11251116
System.err.println("Error in finalize (and could not write to logError): " + t)
11261117
}
11271118
} finally {

0 commit comments

Comments
 (0)