Commit 7768a80

[SPARK-4031] Make torrent broadcast read blocks on use.

This avoids reading torrent broadcast variables when they are referenced in a closure but never used in it. This is done by using a `lazy val` to read the broadcast blocks.

cc rxin JoshRosen for review

Author: Shivaram Venkataraman <[email protected]>

Closes #2871 from shivaram/broadcast-read-value and squashes the following commits:

1456d65 [Shivaram Venkataraman] Use getUsedTimeMs and remove readObject
d6c5ee9 [Shivaram Venkataraman] Use lazy val to implement readBroadcastBlock
0b34df7 [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into broadcast-read-value
9cec507 [Shivaram Venkataraman] Test if broadcast variables are read lazily
768b40b [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into broadcast-read-value
8792ed8 [Shivaram Venkataraman] Make torrent broadcast read blocks on use. This avoids reading broadcast variables when they are referenced in the closure but not used by the code.
1 parent 0ac52e3 commit 7768a80
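The heart of the change is replacing an eagerly assigned `_value` field with a `lazy val`, so deserializing a `TorrentBroadcast` on an executor no longer fetches its blocks; the fetch happens on first access. A minimal, self-contained sketch of that pattern in plain Scala (the `LazyHandle` class and `fetch` function are illustrative names, not part of the patch):

```scala
// A sketch of the lazy-read pattern this commit applies: the expensive read
// runs on first access to `value`, not when the object is constructed or
// deserialized. All names here are hypothetical.
class LazyHandle[T](fetch: () => T) extends Serializable {
  // Evaluated at most once, and only if someone actually asks for it.
  @transient private lazy val _value: T = fetch()
  def value: T = _value
}

object LazyHandleDemo extends App {
  var reads = 0
  val handle = new LazyHandle(() => { reads += 1; "payload" })
  println(reads)        // 0: constructing the handle reads nothing
  println(handle.value) // "payload": first access triggers the read
  println(reads)        // 1: and the result is cached thereafter
}
```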

File tree

3 files changed: +67, -21 lines

- core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
- core/src/main/scala/org/apache/spark/util/Utils.scala
- core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala

Lines changed: 23 additions & 20 deletions

@@ -56,11 +56,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
   extends Broadcast[T](id) with Logging with Serializable {
 
   /**
-   * Value of the broadcast object. On driver, this is set directly by the constructor.
-   * On executors, this is reconstructed by [[readObject]], which builds this value by reading
-   * blocks from the driver and/or other executors.
+   * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
+   * which builds this value by reading blocks from the driver and/or other executors.
+   *
+   * On the driver, if the value is required, it is read lazily from the block manager.
    */
-  @transient private var _value: T = obj
+  @transient private lazy val _value: T = readBroadcastBlock()
+
   /** The compression codec to use, or None if compression is disabled */
   @transient private var compressionCodec: Option[CompressionCodec] = _
   /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */

@@ -79,22 +81,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
   private val broadcastId = BroadcastBlockId(id)
 
   /** Total number of blocks this broadcast variable contains. */
-  private val numBlocks: Int = writeBlocks()
+  private val numBlocks: Int = writeBlocks(obj)
 
-  override protected def getValue() = _value
+  override protected def getValue() = {
+    _value
+  }
 
   /**
    * Divide the object into multiple blocks and put those blocks in the block manager.
-   *
+   * @param value the object to divide
    * @return number of blocks this broadcast variable is divided into
    */
-  private def writeBlocks(): Int = {
+  private def writeBlocks(value: T): Int = {
     // Store a copy of the broadcast variable in the driver so that tasks run on the driver
     // do not create a duplicate copy of the broadcast variable's value.
-    SparkEnv.get.blockManager.putSingle(broadcastId, _value, StorageLevel.MEMORY_AND_DISK,
+    SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK,
       tellMaster = false)
     val blocks =
-      TorrentBroadcast.blockifyObject(_value, blockSize, SparkEnv.get.serializer, compressionCodec)
+      TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
     blocks.zipWithIndex.foreach { case (block, i) =>
       SparkEnv.get.blockManager.putBytes(
         BroadcastBlockId(id, "piece" + i),

@@ -157,31 +161,30 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
     out.defaultWriteObject()
   }
 
-  /** Used by the JVM when deserializing this object. */
-  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
-    in.defaultReadObject()
+  private def readBroadcastBlock(): T = Utils.tryOrIOException {
     TorrentBroadcast.synchronized {
       setConf(SparkEnv.get.conf)
       SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
         case Some(x) =>
-          _value = x.asInstanceOf[T]
+          x.asInstanceOf[T]
 
         case None =>
           logInfo("Started reading broadcast variable " + id)
-          val start = System.nanoTime()
+          val startTimeMs = System.currentTimeMillis()
           val blocks = readBlocks()
-          val time = (System.nanoTime() - start) / 1e9
-          logInfo("Reading broadcast variable " + id + " took " + time + " s")
+          logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
 
-          _value =
-            TorrentBroadcast.unBlockifyObject[T](blocks, SparkEnv.get.serializer, compressionCodec)
+          val obj = TorrentBroadcast.unBlockifyObject[T](
+            blocks, SparkEnv.get.serializer, compressionCodec)
           // Store the merged copy in BlockManager so other tasks on this executor don't
           // need to re-fetch it.
           SparkEnv.get.blockManager.putSingle(
-            broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+            broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+          obj
       }
     }
   }
+
 }
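In user terms: a torrent broadcast that a task closure captures but never dereferences now costs nothing on the executors beyond its small serialized handle. A hedged usage sketch against the public Spark API of the era (the job itself is illustrative, not from the patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object LazyBroadcastSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local-cluster[2, 1, 512]").setAppName("lazy-bc")
    val sc = new SparkContext(conf)
    // Roughly 4MB of payload, about one torrent block at the default size.
    val bc = sc.broadcast(Array.fill(1 << 20)(1))

    // The broadcast handle is serialized with the closure, but .value is
    // never called, so with this patch the executors fetch no blocks.
    sc.parallelize(1 to 4).map { x => val captured = bc; x * 2 }.collect()

    // Only here, where .value is dereferenced, are the blocks read.
    sc.parallelize(1 to 4).map(x => x + bc.value.length).collect()

    sc.stop()
  }
}
```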
core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 15 additions & 0 deletions

@@ -988,6 +988,21 @@ private[spark] object Utils extends Logging {
     }
   }
 
+  /**
+   * Execute a block of code that returns a value, re-throwing any non-fatal uncaught
+   * exceptions as IOException. This is used when implementing Externalizable and Serializable's
+   * read and write methods, since Java's serializer will not report non-IOExceptions properly;
+   * see SPARK-4080 for more context.
+   */
+  def tryOrIOException[T](block: => T): T = {
+    try {
+      block
+    } catch {
+      case e: IOException => throw e
+      case NonFatal(t) => throw new IOException(t)
+    }
+  }
+
   /** Default filtering function for finding call sites using `getCallSite`. */
   private def coreExclusionFunction(className: String): Boolean = {
     // A regular expression to match classes of the "core" Spark API that we want to skip when
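This generic overload of `tryOrIOException` exists so that `readBroadcastBlock()` above can return a value while still converting non-IO failures into `IOException`. A brief usage sketch (the `Example` class is hypothetical; the helper body is copied from the patch so the snippet stands alone):

```scala
import java.io.{IOException, ObjectInputStream}
import scala.util.control.NonFatal

// Hypothetical class with a custom readObject wrapped in tryOrIOException,
// so non-IO failures during deserialization surface as IOException instead
// of confusing Java's serialization machinery.
class Example(var cache: Map[String, Int]) extends Serializable {
  private def readObject(in: ObjectInputStream): Unit = tryOrIOException {
    in.defaultReadObject()
    cache = Map("rebuilt" -> 1) // rebuilding state may throw a non-IO exception
  }

  // Local copy of the helper added to Utils by this commit.
  private def tryOrIOException[T](block: => T): T = {
    try {
      block
    } catch {
      case e: IOException => throw e
      case NonFatal(t) => throw new IOException(t)
    }
  }
}
```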

core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala

Lines changed: 29 additions & 1 deletion

@@ -21,11 +21,28 @@ import scala.util.Random
 
 import org.scalatest.{Assertions, FunSuite}
 
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkEnv}
 import org.apache.spark.io.SnappyCompressionCodec
+import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.storage._
 
+// Dummy class that creates a broadcast variable but doesn't use it
+class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable {
+  @transient val list = List(1, 2, 3, 4)
+  val broadcast = rdd.context.broadcast(list)
+  val bid = broadcast.id
+
+  def doSomething() = {
+    rdd.map { x =>
+      val bm = SparkEnv.get.blockManager
+      // Check if broadcast block was fetched
+      val isFound = bm.getLocal(BroadcastBlockId(bid)).isDefined
+      (x, isFound)
+    }.collect().toSet
+  }
+}
+
 class BroadcastSuite extends FunSuite with LocalSparkContext {
 
   private val httpConf = broadcastConf("HttpBroadcastFactory")

@@ -105,6 +122,17 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
     }
   }
 
+  test("Test Lazy Broadcast variables with TorrentBroadcast") {
+    val numSlaves = 2
+    val conf = torrentConf.clone
+    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
+    val rdd = sc.parallelize(1 to numSlaves)
+
+    val results = new DummyBroadcastClass(rdd).doSomething()
+
+    assert(results.toSet === (1 to numSlaves).map(x => (x, false)).toSet)
+  }
+
   test("Unpersisting HttpBroadcast on executors only in local mode") {
     testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false)
   }