
Commit c7ccef1

Merge branch 'bc-unpersist-merge' of github.com:ignatich/incubator-spark into cleanup

Conflicts:
    core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
    core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
    core/src/main/scala/org/apache/spark/storage/MemoryStore.scala

2 parents: 6c9dcf6 + 80dd977

8 files changed, 202 insertions(+), 43 deletions(-)

core/src/main/scala/org/apache/spark/SparkContext.scala
Lines changed: 6 additions & 1 deletion

@@ -641,8 +641,13 @@ class SparkContext(
    * Broadcast a read-only variable to the cluster, returning a
    * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
    * The variable will be sent to each node only once.
+   *
+   * If `registerBlocks` is true, workers will notify the driver about the blocks they create,
+   * and these blocks will be dropped when the `unpersist` method of the broadcast variable is called.
    */
-  def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
+  def broadcast[T](value: T, registerBlocks: Boolean = false) = {
+    env.broadcastManager.newBroadcast[T](value, isLocal, registerBlocks)
+  }

   /**
    * Add a file to be downloaded with this Spark job on every node.
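For context, a minimal usage sketch of the new API (not part of the patch; assumes a local-mode build of this branch):

import org.apache.spark.{SparkConf, SparkContext}

object BroadcastUnpersistExample {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("bc-unpersist"))

    // registerBlocks = true makes workers report their broadcast blocks to the
    // driver, which is what allows unpersist() below to remove them cluster-wide.
    val bc = sc.broadcast(Map("a" -> 1, "b" -> 2), registerBlocks = true)

    val total = sc.parallelize(Seq("a", "b", "c"))
      .map(w => bc.value.getOrElse(w, 0))
      .reduce(_ + _)
    println(total)

    // Drops in-memory copies everywhere. Passing removeSource = true would also
    // delete the on-disk source and break later re-reads (e.g. RDD re-computation).
    bc.unpersist()

    sc.stop()
  }
}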

core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
Lines changed: 11 additions & 2 deletions

@@ -53,6 +53,15 @@ import org.apache.spark._
 abstract class Broadcast[T](val id: Long) extends Serializable {
   def value: T

+  /**
+   * Removes all blocks of this broadcast from memory (and from disk if removeSource is true).
+   *
+   * @param removeSource Whether to remove data from disk as well.
+   *                     Will cause errors if the broadcast is accessed on workers afterwards
+   *                     (e.g. in case of RDD re-computation due to executor failure).
+   */
+  def unpersist(removeSource: Boolean = false)
+
   // We cannot have an abstract readObject here due to some weird issues with
   // readObject having to be 'private' in sub-classes.

@@ -92,8 +101,8 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager:

   private val nextBroadcastId = new AtomicLong(0)

-  def newBroadcast[T](value_ : T, isLocal: Boolean) =
-    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
+  def newBroadcast[T](value_ : T, isLocal: Boolean, registerBlocks: Boolean) =
+    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), registerBlocks)

   def isDriver = _isDriver
 }
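To make the new contract concrete, a toy, non-distributed subclass (hypothetical, not part of the patch) that satisfies the abstract method:

// Toy subclass illustrating the unpersist contract; purely illustrative.
class LocalBroadcast[T](value0: T, id: Long) extends Broadcast[T](id) {
  @volatile private var v: Option[T] = Some(value0)

  // Throws once unpersisted, mirroring the documented caveat that a broadcast
  // must not be accessed after its source has been removed.
  def value = v.getOrElse(throw new IllegalStateException("broadcast " + id + " was unpersisted"))

  def unpersist(removeSource: Boolean) { v = None }
}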

core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@ import org.apache.spark.SparkConf
  * entire Spark job.
  */
 trait BroadcastFactory {
-  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
-  def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
+  def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T]
   def stop(): Unit
 }
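Any third-party factory has to absorb the extra parameter. A hypothetical implementation wiring up the toy LocalBroadcast sketched above:

import org.apache.spark.{SecurityManager, SparkConf}

// Hypothetical factory demonstrating the widened trait signature.
class LocalBroadcastFactory extends BroadcastFactory {
  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {}
  def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
    new LocalBroadcast[T](value, id)  // a local toy ignores registerBlocks
  def stop() {}
}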

core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
Lines changed: 37 additions & 12 deletions

@@ -29,11 +29,24 @@ import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}

-private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
   extends Broadcast[T](id) with Logging with Serializable {

   def value = value_

+  def unpersist(removeSource: Boolean) {
+    HttpBroadcast.synchronized {
+      SparkEnv.get.blockManager.master.removeBlock(blockId)
+      SparkEnv.get.blockManager.removeBlock(blockId)
+    }
+
+    if (removeSource) {
+      HttpBroadcast.synchronized {
+        HttpBroadcast.cleanupById(id)
+      }
+    }
+  }
+
   def blockId = BroadcastBlockId(id)

   HttpBroadcast.synchronized {

@@ -54,7 +67,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean
     logInfo("Started reading broadcast variable " + id)
     val start = System.nanoTime
     value_ = HttpBroadcast.read[T](id)
-    SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
+    SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)
     val time = (System.nanoTime - start) / 1e9
     logInfo("Reading broadcast variable " + id + " took " + time + " s")
   }

@@ -71,8 +84,8 @@ class HttpBroadcastFactory extends BroadcastFactory {
     HttpBroadcast.initialize(isDriver, conf, securityMgr)
   }

-  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
-    new HttpBroadcast[T](value_, isLocal, id)
+  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
+    new HttpBroadcast[T](value_, isLocal, id, registerBlocks)

   def stop() { HttpBroadcast.stop() }
 }

@@ -136,8 +149,10 @@ private object HttpBroadcast extends Logging {
     logInfo("Broadcast server started at " + serverUri)
   }

+  def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)
+
   def write(id: Long, value: Any) {
-    val file = new File(broadcastDir, BroadcastBlockId(id).name)
+    val file = getFile(id)
     val out: OutputStream = {
       if (compress) {
         compressionCodec.compressedOutputStream(new FileOutputStream(file))

@@ -183,20 +198,30 @@ private object HttpBroadcast extends Logging {
     obj
   }

+  def deleteFile(fileName: String) {
+    try {
+      new File(fileName).delete()
+      logInfo("Deleted broadcast file '" + fileName + "'")
+    } catch {
+      case e: Exception => logWarning("Could not delete broadcast file '" + fileName + "'", e)
+    }
+  }
+
   def cleanup(cleanupTime: Long) {
     val iterator = files.internalMap.entrySet().iterator()
     while(iterator.hasNext) {
       val entry = iterator.next()
       val (file, time) = (entry.getKey, entry.getValue)
       if (time < cleanupTime) {
-        try {
-          iterator.remove()
-          new File(file.toString).delete()
-          logInfo("Deleted broadcast file '" + file + "'")
-        } catch {
-          case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
-        }
+        iterator.remove()
+        deleteFile(file)
       }
     }
   }
+
+  def cleanupById(id: Long) {
+    val file = getFile(id).getAbsolutePath
+    files.internalMap.remove(file)
+    deleteFile(file)
+  }
 }
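The refactor above extracts one shared deleteFile helper out of the timed sweep so the new cleanupById path can reuse it. A standalone sketch of the same pattern (assumptions: println in place of Spark logging, a plain path-to-timestamp map in place of TimeStampedHashSet):

import java.io.File
import java.util.concurrent.ConcurrentHashMap

object BroadcastFileCleanerSketch {
  private val files = new ConcurrentHashMap[String, Long]()

  def track(f: File) { files.put(f.getAbsolutePath, System.currentTimeMillis) }

  private def deleteFile(path: String) {
    if (new File(path).delete()) println("Deleted broadcast file '" + path + "'")
    else println("Could not delete broadcast file '" + path + "'")
  }

  // Old path: sweep every tracked file older than cleanupTime.
  def cleanup(cleanupTime: Long) {
    val it = files.entrySet().iterator()
    while (it.hasNext) {
      val entry = it.next()
      if (entry.getValue < cleanupTime) {
        it.remove()
        deleteFile(entry.getKey)
      }
    }
  }

  // New path: remove one broadcast's file immediately on unpersist.
  def cleanupById(path: String) {
    files.remove(path)
    deleteFile(path)
  }
}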

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
Lines changed: 64 additions & 10 deletions

@@ -26,12 +26,68 @@ import org.apache.spark._
 import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
 import org.apache.spark.util.Utils

-private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
-extends Broadcast[T](id) with Logging with Serializable {
+private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
+  extends Broadcast[T](id) with Logging with Serializable {

   def value = value_

+  def unpersist(removeSource: Boolean) {
+    TorrentBroadcast.synchronized {
+      SparkEnv.get.blockManager.master.removeBlock(broadcastId)
+      SparkEnv.get.blockManager.removeBlock(broadcastId)
+    }
+
+    if (!removeSource) {
+      // We can't tell the BlockManager master to remove blocks from all nodes except the driver,
+      // so we need to save them here in order to store them on disk later.
+      // This may be inefficient if blocks were already dropped to disk,
+      // but since unpersist is supposed to be called right after working with
+      // a broadcast this should not happen (and getting them from memory is cheap).
+      arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
+
+      for (pid <- 0 until totalBlocks) {
+        val pieceId = pieceBlockId(pid)
+        TorrentBroadcast.synchronized {
+          SparkEnv.get.blockManager.getSingle(pieceId) match {
+            case Some(x) =>
+              arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
+            case None =>
+              throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+          }
+        }
+      }
+    }
+
+    for (pid <- 0 until totalBlocks) {
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.master.removeBlock(pieceBlockId(pid))
+      }
+    }
+
+    if (removeSource) {
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.removeBlock(metaId)
+      }
+    } else {
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.dropFromMemory(metaId)
+      }
+
+      for (i <- 0 until totalBlocks) {
+        val pieceId = pieceBlockId(i)
+        TorrentBroadcast.synchronized {
+          SparkEnv.get.blockManager.putSingle(
+            pieceId, arrayOfBlocks(i), StorageLevel.DISK_ONLY, true)
+        }
+      }
+      arrayOfBlocks = null
+    }
+  }
+
   def broadcastId = BroadcastBlockId(id)
+  private def metaId = BroadcastHelperBlockId(broadcastId, "meta")
+  private def pieceBlockId(pid: Int) = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+  private def pieceIds = Array.iterate(0, totalBlocks)(_ + 1).toList

   TorrentBroadcast.synchronized {
     SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)

@@ -54,7 +110,6 @@ extends Broadcast[T](id) with Logging with Serializable {
     hasBlocks = tInfo.totalBlocks

     // Store meta-info
-    val metaId = BroadcastHelperBlockId(broadcastId, "meta")
     val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
     TorrentBroadcast.synchronized {
       SparkEnv.get.blockManager.putSingle(

@@ -63,7 +118,7 @@ extends Broadcast[T](id) with Logging with Serializable {

     // Store individual pieces
     for (i <- 0 until totalBlocks) {
-      val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
+      val pieceId = pieceBlockId(i)
       TorrentBroadcast.synchronized {
         SparkEnv.get.blockManager.putSingle(
           pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)

@@ -93,7 +148,7 @@ extends Broadcast[T](id) with Logging with Serializable {
     // This creates a tradeoff between memory usage and latency.
     // Storing copy doubles the memory footprint; not storing doubles deserialization cost.
     SparkEnv.get.blockManager.putSingle(
-      broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+      broadcastId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)

     // Remove arrayOfBlocks from memory once value_ is on local cache
     resetWorkerVariables()

@@ -116,7 +171,6 @@ extends Broadcast[T](id) with Logging with Serializable {

   def receiveBroadcast(variableID: Long): Boolean = {
     // Receive meta-info
-    val metaId = BroadcastHelperBlockId(broadcastId, "meta")
     var attemptId = 10
     while (attemptId > 0 && totalBlocks == -1) {
       TorrentBroadcast.synchronized {

@@ -139,9 +193,9 @@ extends Broadcast[T](id) with Logging with Serializable {
     }

     // Receive actual blocks
-    val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
+    val recvOrder = new Random().shuffle(pieceIds)
     for (pid <- recvOrder) {
-      val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+      val pieceId = pieceBlockId(pid)
       TorrentBroadcast.synchronized {
         SparkEnv.get.blockManager.getSingle(pieceId) match {
           case Some(x) =>

@@ -245,8 +299,8 @@ class TorrentBroadcastFactory extends BroadcastFactory {
     TorrentBroadcast.initialize(isDriver, conf)
   }

-  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
-    new TorrentBroadcast[T](value_, isLocal, id)
+  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
+    new TorrentBroadcast[T](value_, isLocal, id, registerBlocks)

   def stop() { TorrentBroadcast.stop() }
 }
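The subtle part of TorrentBroadcast.unpersist is ordering: pieces are snapshotted out of memory before anything is removed (since the master cannot be told to drop blocks everywhere except the driver), and only afterwards are the saved pieces re-stored DISK_ONLY. A toy single-process model of that control flow, with plain maps standing in for the BlockManager's memory and disk stores (not Spark code):

import scala.collection.mutable

object TorrentUnpersistSketch {
  val memory = mutable.Map[String, Array[Byte]]()
  val disk = mutable.Map[String, Array[Byte]]()

  def unpersist(pieceIds: Seq[String], removeSource: Boolean) {
    // 1. Snapshot the pieces while they are still in memory, so they can be
    //    re-written to disk when the source must survive the unpersist.
    val saved =
      if (removeSource) Map.empty[String, Array[Byte]]
      else pieceIds.map(id => id -> memory(id)).toMap

    // 2. Drop all in-memory copies (in Spark: master.removeBlock everywhere).
    pieceIds.foreach(memory.remove)

    // 3. Either delete the source outright, or park the saved pieces on disk.
    if (removeSource) pieceIds.foreach(disk.remove)
    else saved.foreach { case (id, bytes) => disk(id) = bytes }
  }
}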

core/src/main/scala/org/apache/spark/storage/BlockManager.scala
Lines changed: 12 additions & 0 deletions

@@ -209,6 +209,11 @@ private[spark] class BlockManager(
     }
   }

+  /**
+   * For testing. Returns the number of blocks the BlockManager knows about that are in memory.
+   */
+  def numberOfBlocksInMemory() = blockInfo.keys.count(memoryStore.contains(_))
+
   /**
    * Get storage level of local block. If no info exists for the block, then returns null.
    */

@@ -812,6 +817,13 @@
   }

   /**
+   * Drop a block from memory, possibly putting it on disk if applicable.
+   */
+  def dropFromMemory(blockId: BlockId) {
+    memoryStore.asInstanceOf[MemoryStore].dropFromMemory(blockId)
+  }
+
+  /**
    * Remove all blocks belonging to the given RDD.
    * @return The number of blocks removed.
    */

core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
Lines changed: 22 additions & 16 deletions

@@ -210,9 +210,27 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
   }

   /**
-   * Try to free up a given amount of space to store a particular block, but can fail if
-   * either the block is bigger than our memory or it would require replacing another block
-   * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
+   * Drop a block from memory, possibly putting it on disk if applicable.
+   */
+  def dropFromMemory(blockId: BlockId) {
+    val entry = entries.synchronized { entries.get(blockId) }
+    // This should never be null if called from ensureFreeSpace as only one
+    // thread should be dropping blocks and removing entries.
+    // However the check is required in other cases.
+    if (entry != null) {
+      val data = if (entry.deserialized) {
+        Left(entry.value.asInstanceOf[ArrayBuffer[Any]])
+      } else {
+        Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
+      }
+      blockManager.dropFromMemory(blockId, data)
+    }
+  }
+
+  /**
+   * Tries to free up a given amount of space to store a particular block, but can fail and return
+   * false if either the block is bigger than our memory or it would require replacing another
+   * block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
    * don't fit into memory that we want to avoid).
    *
    * Assume that a lock is held by the caller to ensure only one thread is dropping blocks.

@@ -254,19 +272,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
     if (maxMemory - (currentMemory - selectedMemory) >= space) {
       logInfo(selectedBlocks.size + " blocks selected for dropping")
       for (blockId <- selectedBlocks) {
-        val entry = entries.synchronized { entries.get(blockId) }
-        // This should never be null as only one thread should be dropping
-        // blocks and removing entries. However the check is still here for
-        // future safety.
-        if (entry != null) {
-          val data = if (entry.deserialized) {
-            Left(entry.value.asInstanceOf[ArrayBuffer[Any]])
-          } else {
-            Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
-          }
-          val droppedBlockStatus = blockManager.dropFromMemory(blockId, data)
-          droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
-        }
+        dropFromMemory(blockId)
       }
       return ResultWithDroppedBlocks(success = true, droppedBlocks)
     } else {
