@@ -20,7 +20,7 @@ package org.apache.spark
20
20
import java .io ._
21
21
import java .util .zip .{GZIPInputStream , GZIPOutputStream }
22
22
23
- import scala .collection .mutable .{HashSet , Map }
23
+ import scala .collection .mutable .{HashSet , HashMap , Map }
24
24
import scala .concurrent .Await
25
25
26
26
import akka .actor ._
@@ -34,6 +34,7 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
34
34
extends MapOutputTrackerMessage
35
35
private [spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
36
36
37
+ /** Actor class for MapOutputTrackerMaster */
37
38
private [spark] class MapOutputTrackerMasterActor (tracker : MapOutputTrackerMaster )
38
39
extends Actor with Logging {
39
40
def receive = {
@@ -50,28 +51,35 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
50
51
}
51
52
52
53
/**
53
- * Class that keeps track of the location of the location of the map output of
54
+ * Class that keeps track of the location of the map output of
54
55
* a stage. This is abstract because different versions of MapOutputTracker
55
56
* (driver and worker) use different HashMap to store its metadata.
56
57
*/
57
58
private [spark] abstract class MapOutputTracker (conf : SparkConf ) extends Logging {
58
59
59
60
private val timeout = AkkaUtils .askTimeout(conf)
60
61
61
- // Set to the MapOutputTrackerActor living on the driver
62
+ /** Set to the MapOutputTrackerActor living on the driver */
62
63
var trackerActor : ActorRef = _
63
64
64
65
/** This HashMap needs to have different storage behavior for driver and worker */
65
66
protected val mapStatuses : Map [Int , Array [MapStatus ]]
66
67
67
- // Incremented every time a fetch fails so that client nodes know to clear
68
- // their cache of map output locations if this happens.
68
+ /**
69
+ * Incremented every time a fetch fails so that client nodes know to clear
70
+ * their cache of map output locations if this happens.
71
+ */
69
72
protected var epoch : Long = 0
70
73
protected val epochLock = new java.lang.Object
71
74
72
- // Send a message to the trackerActor and get its result within a default timeout, or
73
- // throw a SparkException if this fails.
74
- private def askTracker (message : Any ): Any = {
75
+ /** Remembers which map output locations are currently being fetched on a worker */
76
+ private val fetching = new HashSet [Int ]
77
+
78
+ /**
79
+ * Send a message to the trackerActor and get its result within a default timeout, or
80
+ * throw a SparkException if this fails.
81
+ */
82
+ protected def askTracker (message : Any ): Any = {
75
83
try {
76
84
val future = trackerActor.ask(message)(timeout)
77
85
Await .result(future, timeout)
@@ -81,17 +89,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
81
89
}
82
90
}
83
91
84
- // Send a one-way message to the trackerActor, to which we expect it to reply with true.
85
- private def communicate (message : Any ) {
92
+ /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */
93
+ protected def sendTracker (message : Any ) {
86
94
if (askTracker(message) != true ) {
87
95
throw new SparkException (" Error reply received from MapOutputTracker" )
88
96
}
89
97
}
90
98
91
- // Remembers which map output locations are currently being fetched on a worker
92
- private val fetching = new HashSet [ Int ]
93
-
94
- // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
99
+ /**
100
+ * Called from executors to get the server URIs and
101
+ * output sizes of the map outputs of a given shuffle
102
+ */
95
103
def getServerStatuses (shuffleId : Int , reduceId : Int ): Array [(BlockManagerId , Long )] = {
96
104
val statuses = mapStatuses.get(shuffleId).orNull
97
105
if (statuses == null ) {
@@ -150,22 +158,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
150
158
}
151
159
}
152
160
153
- def stop () {
154
- communicate(StopMapOutputTracker )
155
- mapStatuses.clear()
156
- trackerActor = null
157
- }
158
-
159
- // Called to get current epoch number
161
+ /** Called to get current epoch number */
160
162
def getEpoch : Long = {
161
163
epochLock.synchronized {
162
164
return epoch
163
165
}
164
166
}
165
167
166
- // Called on workers to update the epoch number, potentially clearing old outputs
167
- // because of a fetch failure. (Each worker task calls this with the latest epoch
168
- // number on the master at the time it was created.)
168
+ /**
169
+ * Called from executors to update the epoch number, potentially clearing old outputs
170
+ * because of a fetch failure. Each worker task calls this with the latest epoch
171
+ * number on the master at the time it was created.
172
+ */
169
173
def updateEpoch (newEpoch : Long ) {
170
174
epochLock.synchronized {
171
175
if (newEpoch > epoch) {
@@ -175,24 +179,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
175
179
}
176
180
}
177
181
}
178
- }
179
182
180
- /**
181
- * MapOutputTracker for the workers. This uses BoundedHashMap to keep track of
182
- * a limited number of most recently used map output information.
183
- */
184
- private [spark] class MapOutputTrackerWorker (conf : SparkConf ) extends MapOutputTracker (conf) {
183
+ /** Unregister shuffle data */
184
+ def unregisterShuffle (shuffleId : Int ) {
185
+ mapStatuses.remove(shuffleId)
186
+ }
185
187
186
- /**
187
- * Bounded HashMap for storing serialized statuses in the worker. This allows
188
- * the HashMap stay bounded in memory-usage. Things dropped from this HashMap will be
189
- * automatically repopulated by fetching them again from the driver. Its okay to
190
- * keep the cache size small as it unlikely that there will be a very large number of
191
- * stages active simultaneously in the worker.
192
- */
193
- protected val mapStatuses = new BoundedHashMap [Int , Array [MapStatus ]](
194
- conf.getInt(" spark.mapOutputTracker.cacheSize" , 100 ), true
195
- )
188
+ def stop () {
189
+ sendTracker(StopMapOutputTracker )
190
+ mapStatuses.clear()
191
+ trackerActor = null
192
+ }
196
193
}
197
194
198
195
/**
@@ -202,7 +199,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
202
199
private [spark] class MapOutputTrackerMaster (conf : SparkConf )
203
200
extends MapOutputTracker (conf) {
204
201
205
- // Cache a serialized version of the output statuses for each shuffle to send them out faster
202
+ /** Cache a serialized version of the output statuses for each shuffle to send them out faster */
206
203
private var cacheEpoch = epoch
207
204
208
205
/**
@@ -211,7 +208,6 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
211
208
* by TTL-based cleaning (if set). Other than these two
212
209
* scenarios, nothing should be dropped from this HashMap.
213
210
*/
214
-
215
211
protected val mapStatuses = new TimeStampedHashMap [Int , Array [MapStatus ]]()
216
212
private val cachedSerializedStatuses = new TimeStampedHashMap [Int , Array [Byte ]]()
217
213
@@ -232,13 +228,15 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
232
228
}
233
229
}
234
230
231
+ /** Register multiple map output information for the given shuffle */
235
232
def registerMapOutputs (shuffleId : Int , statuses : Array [MapStatus ], changeEpoch : Boolean = false ) {
236
233
mapStatuses.put(shuffleId, Array [MapStatus ]() ++ statuses)
237
234
if (changeEpoch) {
238
235
incrementEpoch()
239
236
}
240
237
}
241
238
239
+ /** Unregister map output information of the given shuffle, mapper and block manager */
242
240
def unregisterMapOutput (shuffleId : Int , mapId : Int , bmAddress : BlockManagerId ) {
243
241
val arrayOpt = mapStatuses.get(shuffleId)
244
242
if (arrayOpt.isDefined && arrayOpt.get != null ) {
@@ -254,11 +252,17 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
254
252
}
255
253
}
256
254
257
- def unregisterShuffle (shuffleId : Int ) {
255
+ /** Unregister shuffle data */
256
+ override def unregisterShuffle (shuffleId : Int ) {
258
257
mapStatuses.remove(shuffleId)
259
258
cachedSerializedStatuses.remove(shuffleId)
260
259
}
261
260
261
+ /** Check if the given shuffle is being tracked */
262
+ def containsShuffle (shuffleId : Int ): Boolean = {
263
+ cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
264
+ }
265
+
262
266
def incrementEpoch () {
263
267
epochLock.synchronized {
264
268
epoch += 1
@@ -295,26 +299,26 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
295
299
bytes
296
300
}
297
301
298
- def contains (shuffleId : Int ): Boolean = {
299
- cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
300
- }
301
-
302
302
override def stop () {
303
303
super .stop()
304
304
metadataCleaner.cancel()
305
305
cachedSerializedStatuses.clear()
306
306
}
307
307
308
- override def updateEpoch (newEpoch : Long ) {
309
- // This might be called on the MapOutputTrackerMaster if we're running in local mode.
310
- }
311
-
312
308
protected def cleanup (cleanupTime : Long ) {
313
309
mapStatuses.clearOldValues(cleanupTime)
314
310
cachedSerializedStatuses.clearOldValues(cleanupTime)
315
311
}
316
312
}
317
313
314
+ /**
315
+ * MapOutputTracker for the workers, which fetches map output information from the driver's
316
+ * MapOutputTrackerMaster.
317
+ */
318
+ private [spark] class MapOutputTrackerWorker (conf : SparkConf ) extends MapOutputTracker (conf) {
319
+ protected val mapStatuses = new HashMap [Int , Array [MapStatus ]]
320
+ }
321
+
318
322
private [spark] object MapOutputTracker {
319
323
private val LOG_BASE = 1.1
320
324
0 commit comments