Commit e95479c
Add tests for unpersisting broadcast
There is currently no way to query the blocks on the executors, an operation that looks simple but is deceptively tricky to accomplish. This commit adds such a mechanism in order to verify, in the tests, that blocks are in fact persisted or unpersisted on the executors.
1 parent 544ac86 commit e95479c

9 files changed: +309 additions, −63 deletions

core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala

Lines changed: 13 additions & 3 deletions
@@ -48,16 +48,26 @@ import java.io.Serializable
  * @tparam T Type of the data contained in the broadcast variable.
  */
 abstract class Broadcast[T](val id: Long) extends Serializable {
+
+  /**
+   * Whether this Broadcast is actually usable. This should be false once persisted state is
+   * removed from the driver.
+   */
+  protected var isValid: Boolean = true
+
   def value: T
 
   /**
-   * Remove all persisted state associated with this broadcast.
+   * Remove all persisted state associated with this broadcast. Overriding implementations
+   * should set isValid to false if persisted state is also removed from the driver.
+   *
    * @param removeFromDriver Whether to remove state from the driver.
+   *                         If true, the resulting broadcast should no longer be valid.
    */
   def unpersist(removeFromDriver: Boolean)
 
-  // We cannot have an abstract readObject here due to some weird issues with
-  // readObject having to be 'private' in sub-classes.
+  // We cannot define abstract readObject and writeObject here due to some weird issues
+  // with these methods having to be 'private' in sub-classes.
 
   override def toString = "Broadcast(" + id + ")"
 }
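To make the new contract concrete, here is a hedged caller's-eye sketch. It is not part of this commit; it assumes a running SparkContext `sc` and a concrete Broadcast implementation such as the two below:

    // Sketch only: illustrates the intended unpersist semantics.
    val b = sc.broadcast(Array(1, 2, 3))

    // removeFromDriver = false: executors drop their cached copies, but the
    // driver keeps its state, so isValid stays true and tasks can re-fetch it.
    b.unpersist(removeFromDriver = false)

    // removeFromDriver = true: driver state is removed too; isValid should flip
    // to false, making any later attempt to ship this broadcast fail fast.
    b.unpersist(removeFromDriver = true)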

core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala

Lines changed: 10 additions & 3 deletions
@@ -17,8 +17,8 @@
 
 package org.apache.spark.broadcast
 
-import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream}
-import java.net.{URL, URLConnection, URI}
+import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream}
+import java.net.{URI, URL, URLConnection}
 import java.util.concurrent.TimeUnit
 
 import it.unimi.dsi.fastutil.io.{FastBufferedInputStream, FastBufferedOutputStream}
@@ -49,10 +49,17 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
    * @param removeFromDriver Whether to remove state from the driver.
    */
   override def unpersist(removeFromDriver: Boolean) {
+    isValid = !removeFromDriver
     HttpBroadcast.unpersist(id, removeFromDriver)
   }
 
-  // Called by JVM when deserializing an object
+  // Used by the JVM when serializing this object
+  private def writeObject(out: ObjectOutputStream) {
+    assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!")
+    out.defaultWriteObject()
+  }
+
+  // Used by the JVM when deserializing this object
   private def readObject(in: ObjectInputStream) {
     in.defaultReadObject()
     HttpBroadcast.synchronized {
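The new writeObject guard means a destroyed broadcast fails at serialization time on the driver, rather than with a confusing fetch failure on the executors. A hedged sketch of that failure mode, again assuming a running SparkContext `sc`:

    val words = sc.broadcast(Seq("a", "b", "c"))
    words.unpersist(removeFromDriver = true)  // isValid is now false

    // Shipping `words` in a task closure serializes it on the driver, so this
    // job should fail immediately with:
    //   "Attempted to serialize a broadcast variable that has been destroyed!"
    sc.parallelize(1 to 3).map(i => words.value(i % 3)).collect()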

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala

Lines changed: 10 additions & 3 deletions
@@ -17,12 +17,12 @@
 
 package org.apache.spark.broadcast
 
-import java.io._
+import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
 
 import scala.math
 import scala.util.Random
 
-import org.apache.spark._
+import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
 import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
 import org.apache.spark.util.Utils
 
@@ -76,10 +76,17 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
    * @param removeFromDriver Whether to remove state from the driver.
    */
   override def unpersist(removeFromDriver: Boolean) {
+    isValid = !removeFromDriver
     TorrentBroadcast.unpersist(id, removeFromDriver)
   }
 
-  // Called by JVM when deserializing an object
+  // Used by the JVM when serializing this object
+  private def writeObject(out: ObjectOutputStream) {
+    assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!")
+    out.defaultWriteObject()
+  }
+
+  // Used by the JVM when deserializing this object
   private def readObject(in: ObjectInputStream) {
     in.defaultReadObject()
     TorrentBroadcast.synchronized {

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 9 additions & 11 deletions
@@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props}
 import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
 import sun.nio.ch.DirectBuffer
 
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException, MapOutputTracker}
+import org.apache.spark._
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.network._
 import org.apache.spark.serializer.Serializer
@@ -58,7 +58,7 @@ private[spark] class BlockManager(
 
   private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
 
-  private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
+  private[storage] val memoryStore = new MemoryStore(this, maxMemory)
   private[storage] val diskStore = new DiskStore(this, diskBlockManager)
 
   // If we use Netty for shuffle, start a new Netty-based shuffle sender service.
@@ -210,9 +210,9 @@ private[spark] class BlockManager(
   }
 
   /**
-   * Get storage level of local block. If no info exists for the block, then returns null.
+   * Get storage level of local block. If no info exists for the block, return None.
    */
-  def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
+  def getLevel(blockId: BlockId): Option[StorageLevel] = blockInfo.get(blockId).map(_.level)
 
   /**
    * Tell the master about the current storage status of a block. This will send a block update
@@ -496,9 +496,8 @@ private[spark] class BlockManager(
 
   /**
    * A short circuited method to get a block writer that can write data directly to disk.
-   * The Block will be appended to the File specified by filename.
-   * This is currently used for writing shuffle files out. Callers should handle error
-   * cases.
+   * The Block will be appended to the File specified by filename. This is currently used for
+   * writing shuffle files out. Callers should handle error cases.
    */
   def getDiskWriter(
     blockId: BlockId,
@@ -816,8 +815,7 @@ private[spark] class BlockManager(
    * @return The number of blocks removed.
    */
   def removeRdd(rddId: Int): Int = {
-    // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps
-    // from RDD.id to blocks.
+    // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks.
    logInfo("Removing RDD " + rddId)
     val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
     blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) }
@@ -827,13 +825,13 @@ private[spark] class BlockManager(
   /**
    * Remove all blocks belonging to the given broadcast.
    */
-  def removeBroadcast(broadcastId: Long) {
+  def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) {
     logInfo("Removing broadcast " + broadcastId)
     val blocksToRemove = blockInfo.keys.filter(_.isBroadcast).collect {
       case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid
       case bid: BroadcastHelperBlockId if bid.broadcastId.broadcastId == broadcastId => bid
     }
-    blocksToRemove.foreach { blockId => removeBlock(blockId) }
+    blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) }
   }
 
   /**
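The getLevel change replaces a null-on-miss return with an Option, so absence is visible in the type; it also lets the slave actor reply with the Option directly (see BlockManagerSlaveActor below). A minimal standalone sketch of the two shapes, using stub types rather than the real BlockManager:

    case class Level(desc: String)
    case class Info(level: Level)
    val blockInfo = Map("b1" -> Info(Level("MEMORY_ONLY")))

    // Old shape: flattens absence to null, inviting NullPointerExceptions
    def getLevelOld(id: String): Level = blockInfo.get(id).map(_.level).orNull

    // New shape: callers must handle the missing case explicitly
    def getLevel(id: String): Option[Level] = blockInfo.get(id).map(_.level)

    getLevel("b2") match {
      case Some(level) => println("stored at " + level.desc)
      case None        => println("no info for this block")  // no null check needed
    }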

core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala

Lines changed: 18 additions & 0 deletions
@@ -147,6 +147,24 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
     askDriverWithReply[Array[StorageStatus]](GetStorageStatus)
   }
 
+  /**
+   * Mainly for testing. Ask the driver to query all executors for their storage levels
+   * regarding this block. This provides an avenue for the driver to learn the storage
+   * levels of blocks it has not been informed of.
+   *
+   * WARNING: This could lead to deadlocks if there are any outstanding messages the
+   * executors are already expecting from the driver. In this case, while the driver is
+   * waiting for the executors to respond to its GetStorageLevel query, the executors
+   * are also waiting for a response from the driver to a prior message.
+   *
+   * The interim solution is to wait for a brief window of time to pass before asking.
+   * This should suffice, since this mechanism is largely introduced for testing only.
+   */
+  def askForStorageLevels(blockId: BlockId, waitTimeMs: Long = 1000) = {
+    Thread.sleep(waitTimeMs)
+    askDriverWithReply[Map[BlockManagerId, StorageLevel]](AskForStorageLevels(blockId))
+  }
+
   /** Stop the driver actor, called only on the Spark driver node */
   def stop() {
     if (driverActor != null) {
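A hedged sketch of how a test might drive this API end to end. It assumes code compiled inside the org.apache.spark package (SparkEnv and the block manager's master reference are package-private) and a SparkContext `sc` backed by real executors, e.g. a local-cluster master:

    import org.apache.spark.SparkEnv
    import org.apache.spark.storage.BroadcastBlockId

    val b = sc.broadcast(1 to 100)
    val blockId = BroadcastBlockId(b.id)
    val master = SparkEnv.get.blockManager.master

    // Run a job so the executors fetch and cache the broadcast block
    sc.parallelize(1 to 10, 2).map(x => b.value.size + x).count()
    assert(master.askForStorageLevels(blockId).nonEmpty)

    // After unpersisting, no block manager should report a storage level for it
    b.unpersist(removeFromDriver = true)
    assert(master.askForStorageLevels(blockId).isEmpty)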

core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala

Lines changed: 23 additions & 1 deletion
@@ -21,7 +21,7 @@ import java.util.{HashMap => JHashMap}
 
 import scala.collection.mutable
 import scala.collection.JavaConversions._
-import scala.concurrent.Future
+import scala.concurrent.{Await, Future}
 import scala.concurrent.duration._
 
 import akka.actor.{Actor, ActorRef, Cancellable}
@@ -126,6 +126,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
     case HeartBeat(blockManagerId) =>
       sender ! heartBeat(blockManagerId)
 
+    case AskForStorageLevels(blockId) =>
+      sender ! askForStorageLevels(blockId)
+
     case other =>
       logWarning("Got unknown message: " + other)
   }
@@ -158,6 +161,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
     blockManagerInfo.values.foreach { bm => bm.slaveActor ! removeMsg }
   }
 
+  /**
+   * Delegate RemoveBroadcast messages to each BlockManager because the master may not be
+   * notified of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only
+   * removed from the executors, but not from the driver.
+   */
   private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) {
     // TODO(aor): Consolidate usages of <driver>
     val removeMsg = RemoveBroadcast(broadcastId)
@@ -246,6 +254,19 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
     }.toArray
   }
 
+  // For testing. Ask all block managers for the given block's local storage level, if any.
+  private def askForStorageLevels(blockId: BlockId): Map[BlockManagerId, StorageLevel] = {
+    val getStorageLevel = GetStorageLevel(blockId)
+    blockManagerInfo.values.flatMap { info =>
+      val future = info.slaveActor.ask(getStorageLevel)(akkaTimeout)
+      val result = Await.result(future, akkaTimeout)
+      if (result != null) {
+        // If the block does not exist on the slave, the slave replies None
+        result.asInstanceOf[Option[StorageLevel]].map { reply => (info.blockManagerId, reply) }
+      } else None
+    }.toMap
+  }
+
   private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
     if (!blockManagerInfo.contains(id)) {
       blockManagerIdByExecutor.get(id.executorId) match {
@@ -329,6 +350,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
     // Note that this logic will select the same node multiple times if there aren't enough peers
     Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq
   }
+
 }
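The askForStorageLevels implementation fans a blocking ask out to every block manager and keeps only the non-empty replies. A standalone sketch of that flatMap/Await/toMap pattern, with plain Futures standing in for the actor asks:

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration._
    import scala.concurrent.ExecutionContext.Implicits.global

    case class Peer(name: String, level: Option[String])
    val peers = Seq(Peer("exec-1", Some("MEMORY_ONLY")), Peer("exec-2", None))

    val answers: Map[String, String] = peers.flatMap { peer =>
      val future = Future { peer.level }      // stands in for slaveActor.ask(...)
      Await.result(future, 1.second)          // block for this peer's reply
        .map { level => (peer.name, level) }  // drop peers that replied None
    }.toMap
    // answers == Map("exec-1" -> "MEMORY_ONLY")

Blocking on each reply in turn serializes the round trips; that is acceptable for a test-only path, but it is part of why the deadlock warning in BlockManagerMaster above applies.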

core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala

Lines changed: 7 additions & 0 deletions
@@ -43,6 +43,9 @@ private[storage] object BlockManagerMessages {
   case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true)
     extends ToBlockManagerSlave
 
+  // For testing. Ask the slave for the block's storage level.
+  case class GetStorageLevel(blockId: BlockId) extends ToBlockManagerSlave
+
 
   //////////////////////////////////////////////////////////////////////////////////
   // Messages from slaves to the master.
@@ -116,4 +119,8 @@ private[storage] object BlockManagerMessages {
   case object ExpireDeadHosts extends ToBlockManagerMaster
 
   case object GetStorageStatus extends ToBlockManagerMaster
+
+  // For testing. Have the master ask all slaves for the given block's storage level.
+  case class AskForStorageLevels(blockId: BlockId) extends ToBlockManagerMaster
+
 }

core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala

Lines changed: 5 additions & 2 deletions
@@ -47,7 +47,10 @@ class BlockManagerSlaveActor(
       mapOutputTracker.unregisterShuffle(shuffleId)
     }
 
-    case RemoveBroadcast(broadcastId, _) =>
-      blockManager.removeBroadcast(broadcastId)
+    case RemoveBroadcast(broadcastId, removeFromDriver) =>
+      blockManager.removeBroadcast(broadcastId, removeFromDriver)
+
+    case GetStorageLevel(blockId) =>
+      sender ! blockManager.getLevel(blockId)
   }
 }
