Commit f181560

mateiz authored and rxin committed
Merge pull request alteryx#68 from mosharaf/master
Faster and stable/reliable broadcast

HttpBroadcast is noticeably slow, but the alternatives (TreeBroadcast or BitTorrentBroadcast) are notoriously unreliable. The main problem with them is that they try to manage the memory for the pieces of a broadcast themselves. Right now, the BroadcastManager does not know on which machines the tasks reading from a broadcast variable are running or when they have finished. Consequently, we try to guess and often guess wrong, which blows up the memory usage and kills/hangs jobs.

This very simple implementation solves the problem by not trying to manage the intermediate pieces; instead, it offloads that duty to the BlockManager, which is quite good at juggling blocks. Otherwise, it is very similar to the BitTorrentBroadcast implementation (without fancy optimizations). And it runs much faster than the HttpBroadcast we have right now.

I've been using this for another project for the last couple of weeks, and just today did some benchmarking against the Http one. The following shows the improvements for increasing broadcast size for cold runs. Each line represents the number of receivers.

![fix-bc-first](https://f.cloud.github.com/assets/232966/1349342/ffa149e4-36e7-11e3-9fa6-c74555829356.png)

After the first broadcast is over, i.e., after the JVM is warmed up and for HttpBroadcast the server is already running (I think), the following are the improvements for warm runs.

![fix-bc-succ](https://f.cloud.github.com/assets/232966/1349352/5a948bae-36e8-11e3-98ce-34f19ebd33e0.jpg)

The curves are not as nice as the cold runs, but the improvements are obvious, especially for larger broadcasts and more receivers.

Depending on how it goes, we should deprecate and/or remove the old TreeBroadcast and BitTorrentBroadcast implementations, and hopefully, SPARK-889 will not be necessary any more.

(cherry picked from commit e5316d0)
Signed-off-by: Reynold Xin <[email protected]>
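The flow described above (publish a meta entry and the pieces to a shared store, then have each receiver poll for the meta entry and pull the pieces in random order) can be sketched roughly as follows. This is only an illustration under assumptions: `PieceStoreSketch` and its in-memory map are hypothetical stand-ins for the BlockManager, not Spark's API, and the key names simply mirror the `broadcast_<id>_meta` / `broadcast_<id>_piece<i>` naming used in this commit.

```scala
import scala.collection.mutable
import scala.util.Random

// Hypothetical in-memory stand-in for the BlockManager; not Spark's API.
object PieceStoreSketch {
  private val store = mutable.Map.empty[String, Array[Byte]]

  def put(key: String, bytes: Array[Byte]): Unit = synchronized { store(key) = bytes }
  def get(key: String): Option[Array[Byte]] = synchronized { store.get(key) }

  // Sender side: publish the piece count under a "meta" key, then every piece.
  def publish(id: Long, pieces: Seq[Array[Byte]]): Unit = {
    put(s"broadcast_${id}_meta", pieces.length.toString.getBytes("UTF-8"))
    pieces.zipWithIndex.foreach { case (p, i) => put(s"broadcast_${id}_piece$i", p) }
  }

  // Receiver side: poll for the meta entry a bounded number of times, then
  // fetch the pieces in random order and concatenate them in index order.
  def fetch(id: Long, attempts: Int = 10, sleepMs: Long = 10): Option[Array[Byte]] = {
    var total = -1
    var left = attempts
    while (left > 0 && total == -1) {
      get(s"broadcast_${id}_meta") match {
        case Some(m) => total = new String(m, "UTF-8").toInt
        case None => Thread.sleep(sleepMs)
      }
      left -= 1
    }
    if (total == -1) None
    else {
      val pieces = new Array[Array[Byte]](total)
      for (pid <- Random.shuffle((0 until total).toList)) {
        pieces(pid) = get(s"broadcast_${id}_piece$pid").getOrElse(
          throw new IllegalStateException(s"missing piece $pid of broadcast $id"))
      }
      Some(pieces.flatten)
    }
  }
}
```

The random fetch order is what spreads load in the real setting: once a receiver has stored a piece locally, other receivers can pull that piece from it instead of from the original sender.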
1 parent eaa2150 commit f181560

7 files changed, with 328 additions and 12 deletions
Lines changed: 247 additions & 0 deletions
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io._
+
+import scala.math
+import scala.util.Random
+
+import org.apache.spark._
+import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+
+private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+  extends Broadcast[T](id) with Logging with Serializable {
+
+  def value = value_
+
+  def broadcastId = BroadcastBlockId(id)
+
+  TorrentBroadcast.synchronized {
+    SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+  }
+
+  @transient var arrayOfBlocks: Array[TorrentBlock] = null
+  @transient var totalBlocks = -1
+  @transient var totalBytes = -1
+  @transient var hasBlocks = 0
+
+  if (!isLocal) {
+    sendBroadcast()
+  }
+
+  def sendBroadcast() {
+    var tInfo = TorrentBroadcast.blockifyObject(value_)
+
+    totalBlocks = tInfo.totalBlocks
+    totalBytes = tInfo.totalBytes
+    hasBlocks = tInfo.totalBlocks
+
+    // Store meta-info
+    val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+    val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
+    TorrentBroadcast.synchronized {
+      SparkEnv.get.blockManager.putSingle(
+        metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
+    }
+
+    // Store individual pieces
+    for (i <- 0 until totalBlocks) {
+      val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.putSingle(
+          pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
+      }
+    }
+  }
+
+  // Called by JVM when deserializing an object
+  private def readObject(in: ObjectInputStream) {
+    in.defaultReadObject()
+    TorrentBroadcast.synchronized {
+      SparkEnv.get.blockManager.getSingle(broadcastId) match {
+        case Some(x) =>
+          value_ = x.asInstanceOf[T]
+
+        case None =>
+          val start = System.nanoTime
+          logInfo("Started reading broadcast variable " + id)
+
+          // Initialize @transient variables that will receive garbage values from the master.
+          resetWorkerVariables()
+
+          if (receiveBroadcast(id)) {
+            value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+
+            // Store the merged copy in cache so that the next worker doesn't need to rebuild it.
+            // This creates a tradeoff between memory usage and latency.
+            // Storing copy doubles the memory footprint; not storing doubles deserialization cost.
+            SparkEnv.get.blockManager.putSingle(
+              broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+
+            // Remove arrayOfBlocks from memory once value_ is on local cache
+            resetWorkerVariables()
+          } else {
+            logError("Reading broadcast variable " + id + " failed")
+          }
+
+          val time = (System.nanoTime - start) / 1e9
+          logInfo("Reading broadcast variable " + id + " took " + time + " s")
+      }
+    }
+  }
+
+  private def resetWorkerVariables() {
+    arrayOfBlocks = null
+    totalBytes = -1
+    totalBlocks = -1
+    hasBlocks = 0
+  }
+
+  def receiveBroadcast(variableID: Long): Boolean = {
+    // Receive meta-info
+    val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+    var attemptId = 10
+    while (attemptId > 0 && totalBlocks == -1) {
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.getSingle(metaId) match {
+          case Some(x) =>
+            val tInfo = x.asInstanceOf[TorrentInfo]
+            totalBlocks = tInfo.totalBlocks
+            totalBytes = tInfo.totalBytes
+            arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
+            hasBlocks = 0
+
+          case None =>
+            Thread.sleep(500)
+        }
+      }
+      attemptId -= 1
+    }
+    if (totalBlocks == -1) {
+      return false
+    }
+
+    // Receive actual blocks
+    val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
+    for (pid <- recvOrder) {
+      val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.getSingle(pieceId) match {
+          case Some(x) =>
+            arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
+            hasBlocks += 1
+            SparkEnv.get.blockManager.putSingle(
+              pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
+
+          case None =>
+            throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+        }
+      }
+    }
+
+    (hasBlocks == totalBlocks)
+  }
+
+}
+
+private object TorrentBroadcast
+  extends Logging {
+
+  private var initialized = false
+
+  def initialize(_isDriver: Boolean) {
+    synchronized {
+      if (!initialized) {
+        initialized = true
+      }
+    }
+  }
+
+  def stop() {
+    initialized = false
+  }
+
+  val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024
+
+  def blockifyObject[T](obj: T): TorrentInfo = {
+    val byteArray = Utils.serialize[T](obj)
+    val bais = new ByteArrayInputStream(byteArray)
+
+    var blockNum = (byteArray.length / BLOCK_SIZE)
+    if (byteArray.length % BLOCK_SIZE != 0)
+      blockNum += 1
+
+    var retVal = new Array[TorrentBlock](blockNum)
+    var blockID = 0
+
+    for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
+      val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
+      var tempByteArray = new Array[Byte](thisBlockSize)
+      val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
+
+      retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
+      blockID += 1
+    }
+    bais.close()
+
+    var tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
+    tInfo.hasBlocks = blockNum
+
+    return tInfo
+  }
+
+  def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
+                          totalBytes: Int,
+                          totalBlocks: Int): T = {
+    var retByteArray = new Array[Byte](totalBytes)
+    for (i <- 0 until totalBlocks) {
+      System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
+        i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
+    }
+    Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
+  }
+
+}
+
+private[spark] case class TorrentBlock(
+    blockID: Int,
+    byteArray: Array[Byte])
+  extends Serializable
+
+private[spark] case class TorrentInfo(
+    @transient arrayOfBlocks : Array[TorrentBlock],
+    totalBlocks: Int,
+    totalBytes: Int)
+  extends Serializable {
+
+  @transient var hasBlocks = 0
+}
+
+private[spark] class TorrentBroadcastFactory
+  extends BroadcastFactory {
+
+  def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) }
+
+  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+    new TorrentBroadcast[T](value_, isLocal, id)
+
+  def stop() { TorrentBroadcast.stop() }
+}
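For readers following `blockifyObject` and `unBlockifyObject`, the block arithmetic can be tried in isolation with a small sketch like the one below. It assumes plain byte arrays in place of serialized objects, and the names `BlockifySketch`, `blockify`, `unblockify` and `blockCount` are hypothetical, not part of this commit. With the default `spark.broadcast.blockSize` of 4096 (KB), the block size in the diff works out to 4096 * 1024 = 4,194,304 bytes; the sketch takes the block size as a parameter so it can be tested with tiny values.

```scala
object BlockifySketch {
  // Number of blocks: integer ceiling of length / blockSize,
  // computed the same way as blockNum in the diff.
  def blockCount(length: Int, blockSize: Int): Int =
    length / blockSize + (if (length % blockSize != 0) 1 else 0)

  // Split a byte array into consecutive blocks of at most blockSize bytes.
  // Only the last block may be shorter than blockSize.
  def blockify(bytes: Array[Byte], blockSize: Int): Array[Array[Byte]] = {
    require(blockSize > 0, "blockSize must be positive")
    bytes.grouped(blockSize).toArray
  }

  // Copy block i back to offset i * blockSize of the output buffer,
  // which is the same offset rule unBlockifyObject uses.
  def unblockify(blocks: Array[Array[Byte]], totalBytes: Int, blockSize: Int): Array[Byte] = {
    val out = new Array[Byte](totalBytes)
    for (i <- blocks.indices) {
      System.arraycopy(blocks(i), 0, out, i * blockSize, blocks(i).length)
    }
    out
  }
}
```

Because every block except possibly the last has exactly `blockSize` bytes, the offset `i * blockSize` is valid regardless of the order in which the pieces arrived, which is why the receiver can fetch them in shuffled order.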

core/src/main/scala/org/apache/spark/storage/BlockId.scala

Lines changed: 8 additions & 1 deletion
@@ -32,7 +32,7 @@ private[spark] sealed abstract class BlockId {
   def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
   def isRDD = isInstanceOf[RDDBlockId]
   def isShuffle = isInstanceOf[ShuffleBlockId]
-  def isBroadcast = isInstanceOf[BroadcastBlockId]
+  def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]
 
   override def toString = name
   override def hashCode = name.hashCode
@@ -55,6 +55,10 @@ private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
   def name = "broadcast_" + broadcastId
 }
 
+private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
+  def name = broadcastId.name + "_" + hType
+}
+
 private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
   def name = "taskresult_" + taskId
 }
@@ -72,6 +76,7 @@ private[spark] object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
   val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
   val BROADCAST = "broadcast_([0-9]+)".r
+  val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
   val STREAM = "input-([0-9]+)-([0-9]+)".r
   val TEST = "test_(.*)".r
@@ -84,6 +89,8 @@ private[spark] object BlockId {
       ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
     case BROADCAST(broadcastId) =>
       BroadcastBlockId(broadcastId.toLong)
+    case BROADCAST_HELPER(broadcastId, hType) =>
+      BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
     case TASKRESULT(taskId) =>
       TaskResultBlockId(taskId.toLong)
     case STREAM(streamId, uniqueId) =>
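One detail worth checking in the hunk above is that the `BROADCAST` case is tried before `BROADCAST_HELPER`. That is safe because Scala `Regex` extractors match the whole string by default, so `"broadcast_3_meta"` does not match `BROADCAST`. A minimal sketch, using simplified hypothetical case classes (`Broadcast`, `BroadcastHelper`) rather than Spark's `BlockId` hierarchy, with the two patterns copied from the diff:

```scala
object BlockNameSketch {
  // Simplified stand-ins for BroadcastBlockId and BroadcastHelperBlockId.
  sealed trait Id
  case class Broadcast(id: Long) extends Id
  case class BroadcastHelper(parent: Broadcast, hType: String) extends Id

  // Same patterns as in the diff.
  val BROADCAST = "broadcast_([0-9]+)".r
  val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r

  // Regex extractors are anchored, so the first case only matches names
  // with no suffix; helper names fall through to the second case.
  def parse(name: String): Option[Id] = name match {
    case BROADCAST(id) => Some(Broadcast(id.toLong))
    case BROADCAST_HELPER(id, hType) => Some(BroadcastHelper(Broadcast(id.toLong), hType))
    case _ => None
  }
}
```

Under this sketch, `"broadcast_3_piece12"` parses to a helper ID with `hType` equal to `"piece12"`, matching the `"piece" + i` names written by `sendBroadcast`.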

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 3 additions & 2 deletions
@@ -21,6 +21,7 @@ import java.io.{InputStream, OutputStream}
 import java.nio.{ByteBuffer, MappedByteBuffer}
 
 import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
+import scala.util.Random
 
 import akka.actor.{ActorSystem, Cancellable, Props}
 import akka.dispatch.{Await, Future}
@@ -269,7 +270,7 @@ private[spark] class BlockManager(
   }
 
   /**
-   * Actually send a UpdateBlockInfo message. Returns the mater's response,
+   * Actually send a UpdateBlockInfo message. Returns the master's response,
    * which will be true if the block was successfully recorded and false if
    * the slave needs to re-register.
    */
@@ -478,7 +479,7 @@ private[spark] class BlockManager(
     }
     logDebug("Getting remote block " + blockId)
     // Get locations of block
-    val locations = master.getLocations(blockId)
+    val locations = Random.shuffle(master.getLocations(blockId))
 
     // Get block from remote locations
     for (loc <- locations) {
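The `Random.shuffle` change above means each reader walks the known locations of a block in a different order, so requests are spread over every node that already holds a copy instead of always hitting the first one listed. A rough, self-contained sketch of that pattern follows; `fetchFromAny`, the host names and the `fetch` callback are hypothetical and do not correspond to BlockManager methods.

```scala
import scala.util.Random

object ShuffleLocationsSketch {
  // Try the locations in random order and return the first successful fetch.
  // `fetch` stands in for a remote read that may fail (None).
  def fetchFromAny[A](locations: Seq[String], rng: Random)(fetch: String => Option[A]): Option[A] = {
    val it = rng.shuffle(locations).iterator
    var result: Option[A] = None
    while (result.isEmpty && it.hasNext) {
      result = fetch(it.next())
    }
    result
  }
}
```

For example, `ShuffleLocationsSketch.fetchFromAny(Seq("hostA", "hostB"), new Random())(h => Some(h))` returns either host, and over many calls the two should be picked about equally often.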

core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala

Lines changed: 1 addition & 3 deletions
@@ -227,9 +227,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
   }
 
   private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
-    if (id.executorId == "<driver>" && !isLocal) {
-      // Got a register message from the master node; don't register it
-    } else if (!blockManagerInfo.contains(id)) {
+    if (!blockManagerInfo.contains(id)) {
       blockManagerIdByExecutor.get(id.executorId) match {
         case Some(manager) =>
           // A block manager of the same executor already exists.

core/src/test/scala/org/apache/spark/BroadcastSuite.scala

Lines changed: 49 additions & 3 deletions
@@ -20,20 +20,66 @@ package org.apache.spark
 import org.scalatest.FunSuite
 
 class BroadcastSuite extends FunSuite with LocalSparkContext {
-
-  test("basic broadcast") {
+
+  override def afterEach() {
+    super.afterEach()
+    System.clearProperty("spark.broadcast.factory")
+  }
+
+  test("Using HttpBroadcast locally") {
+    System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+    sc = new SparkContext("local", "test")
+    val list = List(1, 2, 3, 4)
+    val listBroadcast = sc.broadcast(list)
+    val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
+    assert(results.collect.toSet === Set((1, 10), (2, 10)))
+  }
+
+  test("Accessing HttpBroadcast variables from multiple threads") {
+    System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+    sc = new SparkContext("local[10]", "test")
+    val list = List(1, 2, 3, 4)
+    val listBroadcast = sc.broadcast(list)
+    val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
+    assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+  }
+
+  test("Accessing HttpBroadcast variables in a local cluster") {
+    System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+    val numSlaves = 4
+    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
+    val list = List(1, 2, 3, 4)
+    val listBroadcast = sc.broadcast(list)
+    val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
+    assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+  }
+
+  test("Using TorrentBroadcast locally") {
+    System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
     sc = new SparkContext("local", "test")
     val list = List(1, 2, 3, 4)
     val listBroadcast = sc.broadcast(list)
     val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
     assert(results.collect.toSet === Set((1, 10), (2, 10)))
   }
 
-  test("broadcast variables accessed in multiple threads") {
+  test("Accessing TorrentBroadcast variables from multiple threads") {
+    System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
     sc = new SparkContext("local[10]", "test")
     val list = List(1, 2, 3, 4)
     val listBroadcast = sc.broadcast(list)
     val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
     assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
   }
+
+  test("Accessing TorrentBroadcast variables in a local cluster") {
+    System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
+    val numSlaves = 4
+    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
+    val list = List(1, 2, 3, 4)
+    val listBroadcast = sc.broadcast(list)
+    val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
+    assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+  }
+
 }
