
[SPARK-4454] Properly synchronize accesses to DAGScheduler cacheLocs map #4660

Status: Closed · 1 commit
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala: 34 changes (24 additions, 10 deletions)
@@ -98,7 +98,13 @@ class DAGScheduler(

   private[scheduler] val activeJobs = new HashSet[ActiveJob]
 
-  // Contains the locations that each RDD's partitions are cached on
+  /**
+   * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids
+   * and its values are arrays indexed by partition numbers. Each array value is the set of
+   * locations where that RDD partition is cached.
+   *
+   * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454).
+   */
   private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
 
   // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with
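
The new doc comment pins down both the map's shape (RDD id, then a per-partition array, then the set of cached locations) and its locking rule. As an illustration only, here is a minimal self-contained sketch of that shape and rule; `CacheLocsSketch`, `locsFor`, and the stand-in `TaskLocation` case class are hypothetical names for this sketch, not part of the patch or of Spark's API:

```scala
import scala.collection.mutable.HashMap

// Stand-in for Spark's TaskLocation, just to keep the sketch self-contained.
case class TaskLocation(host: String, executorId: String)

object CacheLocsSketch {
  // Key: RDD id. Value: one slot per partition; each slot lists the
  // locations where that partition is cached.
  private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]

  // Per the doc comment above, every access synchronizes on the map itself.
  def locsFor(rddId: Int, partition: Int): Seq[TaskLocation] =
    cacheLocs.synchronized {
      cacheLocs.get(rddId).map(_(partition)).getOrElse(Nil)
    }

  def main(args: Array[String]): Unit = {
    cacheLocs.synchronized {
      cacheLocs(7) = Array(Seq(TaskLocation("host1", "exec-1")), Nil)
    }
    println(locsFor(7, 0)) // List(TaskLocation(host1,exec-1))
  }
}
```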
@@ -183,18 +189,17 @@ class DAGScheduler(
     eventProcessLoop.post(TaskSetFailed(taskSet, reason))
   }
 
-  private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
-    if (!cacheLocs.contains(rdd.id)) {
+  private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = cacheLocs.synchronized {
+    cacheLocs.getOrElseUpdate(rdd.id, {
       val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
       val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
-      cacheLocs(rdd.id) = blockIds.map { id =>
+      blockIds.map { id =>
         locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
       }
-    }
-    cacheLocs(rdd.id)
+    })
   }
 
-  private def clearCacheLocs() {
+  private def clearCacheLocs(): Unit = cacheLocs.synchronized {
     cacheLocs.clear()
   }
 
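The old body was a classic check-then-act: `contains`, then an update, then a second lookup. On an unsynchronized `mutable.HashMap` that is unsafe once `getPreferredLocs` can run on a SparkContext caller's thread while the event loop also touches the map: a value may be computed twice, and concurrent mutation can corrupt the map's internal structure. The rewrite makes the whole lookup-compute-insert sequence atomic. A minimal sketch of the two shapes; `unsafeGet`, `safeGet`, and `compute` are hypothetical names for illustration, not Spark code:

```scala
import scala.collection.mutable.HashMap

object SynchronizedMemoSketch {
  private val cache = new HashMap[Int, String]

  // Old shape: check-then-act on a non-thread-safe map. Two threads can
  // interleave between contains() and the update, and concurrent writes
  // to mutable.HashMap can corrupt its internal state.
  def unsafeGet(key: Int): String = {
    if (!cache.contains(key)) {
      cache(key) = compute(key)
    }
    cache(key)
  }

  // New shape: lookup, compute, and insert happen atomically because
  // every access holds the map's own monitor.
  def safeGet(key: Int): String = cache.synchronized {
    cache.getOrElseUpdate(key, compute(key))
  }

  private def compute(key: Int): String = s"value-$key"

  def main(args: Array[String]): Unit = {
    val threads = (1 to 8).map { _ =>
      new Thread(() => (0 until 1000).foreach(safeGet))
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
    println(cache.size) // 1000, each value computed exactly once
  }
}
```

Note that, as in the patch, the compute step runs while the monitor is held; `getCacheLocs` likewise performs its BlockManager lookup under the lock, trading lock granularity for a simpler atomicity argument.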
@@ -1276,17 +1281,26 @@
   }
 
   /**
-   * Synchronized method that might be called from other threads.
    * Gets the locality information associated with a partition of a particular RDD.
+   *
+   * This method is thread-safe and is called from both DAGScheduler and SparkContext.
+   *
    * @param rdd whose partitions are to be looked at
    * @param partition to lookup locality information for
    * @return list of machines that are preferred by the partition
    */
   private[spark]
-  def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized {
+  def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
     getPreferredLocsInternal(rdd, partition, new HashSet)
   }
 
-  /** Recursive implementation for getPreferredLocs. */
+  /**
+   * Recursive implementation for getPreferredLocs.
+   *
+   * This method is thread-safe because it only accesses DAGScheduler state through thread-safe
+   * methods (getCacheLocs()); please be careful when modifying this method, because any new
+   * DAGScheduler state accessed by it may require additional synchronization.
+   */
   private def getPreferredLocsInternal(
       rdd: RDD[_],
       partition: Int,
…
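
This last hunk moves the locking down a level: `getPreferredLocs` drops its own `synchronized` (which locked the entire DAGScheduler instance) and the recursion stays safe because it reaches shared mutable state only through the already-synchronized `getCacheLocs`. A sketch of that thread-safe-by-composition structure; `preferredLocs`, `cachedLocs`, and the parent-chain walk here are illustrative stand-ins, not Spark's actual lineage traversal:

```scala
import scala.collection.mutable.HashMap

object ThreadSafeByComposition {
  private val cache = new HashMap[Int, Seq[String]]

  // The only access to shared mutable state, guarded by the map's own
  // monitor (the role getCacheLocs plays in the patch).
  private def cachedLocs(id: Int): Seq[String] = cache.synchronized {
    cache.getOrElseUpdate(id, if (id % 2 == 0) Seq(s"host-$id") else Nil)
  }

  // Public entry point with no lock of its own, mirroring the patched
  // getPreferredLocs: callable from any thread because the recursion
  // below reaches shared state only through cachedLocs().
  def preferredLocs(id: Int): Seq[String] =
    preferredLocsInternal(id, visited = Set.empty)

  // Recursive helper in the style of getPreferredLocsInternal: fall back
  // to a "parent" (here simply id - 1) when nothing is cached for id.
  private def preferredLocsInternal(id: Int, visited: Set[Int]): Seq[String] = {
    if (id < 0 || visited(id)) Nil
    else {
      val cached = cachedLocs(id)
      if (cached.nonEmpty) cached
      else preferredLocsInternal(id - 1, visited + id)
    }
  }

  def main(args: Array[String]): Unit =
    println(preferredLocs(5)) // List(host-4)
}
```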