
Commit 3b11f43

Merge pull request alteryx#57 from aarondav/bid
Refactor BlockId into an actual type

Converts all of our BlockId strings into actual BlockId types. Here are some advantages of doing this now:

+ Type safety
+ Code clarity - it's now obvious what the key of a shuffle or rdd block is, for instance. Additionally, appearing in tuple/map type signatures is a big readability bonus. A Seq[(String, BlockStatus)] is not very clear. Further, we can now use more Scala features, like matching on BlockId types.
+ Explicit usage - we can now formally tell where various BlockIds are being used (without doing string searches); this makes updating current BlockIds a much clearer process, and compiler-supported. (I'm looking at you, shuffle file consolidation.)
+ It will only get harder to make this change as time goes on.

Downside is, of course, that this is a very invasive change touching a lot of different files, which will inevitably lead to merge conflicts for many.
2 parents 9979690 + 4a45019 commit 3b11f43
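For readers skimming the hunks below: the new org.apache.spark.storage.BlockId hierarchy itself is added elsewhere in this commit and is not among the files shown here. The following is a minimal sketch inferred from the usages in this diff, not the actual BlockId.scala. The rdd_* and shuffle_* name formats come from the string-building code being replaced; the broadcast, task-result and test formats, and the fallback parsing case, are assumptions for illustration only.

// Sketch only: a sealed family of typed block ids with a stable string name.
// Inferred from this diff; formats marked "assumed" are not confirmed by it.
package org.apache.spark.storage

sealed abstract class BlockId {
  /** String form, used wherever a plain name is still needed (file names, wire format). */
  def name: String
  override def toString = name
}

case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
  def name = "rdd_" + rddId + "_" + splitIndex            // from old "rdd_%d_%d".format(...)
}

case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
  def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId  // from old "shuffle_%d_%d_%d"
}

case class BroadcastBlockId(broadcastId: Long) extends BlockId {
  def name = "broadcast_" + broadcastId                    // assumed format
}

case class TaskResultBlockId(taskId: Long) extends BlockId {
  def name = "taskresult_" + taskId                        // matches old "taskresult_" + taskId
}

case class TestBlockId(id: String) extends BlockId {
  def name = "test_" + id                                  // assumed format
}

object BlockId {
  private val Rdd = "rdd_([0-9]+)_([0-9]+)".r
  private val Shuffle = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r

  /** Parse a name back into a typed id, as FileServerHandler and ShuffleCopier do below. */
  def apply(name: String): BlockId = name match {
    case Rdd(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt)
    case Shuffle(shuffleId, mapId, reduceId) =>
      ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
    case other => TestBlockId(other)                       // fallback for this sketch only
  }
}

With case classes in place, the regex match in BlockStoreShuffleFetcher below collapses to a plain pattern match on ShuffleBlockId, and call sites that still need a string (file names, the netty wire format) go through .name.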


44 files changed, +544 −385 lines

core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java
Lines changed: 2 additions & 1 deletion

@@ -21,6 +21,7 @@
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelInboundByteHandlerAdapter;
 
+import org.apache.spark.storage.BlockId;
 
 abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
 
@@ -33,7 +34,7 @@ public boolean isComplete() {
   }
 
   public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
-  public abstract void handleError(String blockId);
+  public abstract void handleError(BlockId blockId);
 
   @Override
   public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
Lines changed: 4 additions & 2 deletions

@@ -24,6 +24,7 @@
 import io.netty.channel.ChannelInboundMessageHandlerAdapter;
 import io.netty.channel.DefaultFileRegion;
 
+import org.apache.spark.storage.BlockId;
 
 class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
 
@@ -34,8 +35,9 @@ public FileServerHandler(PathResolver pResolver){
   }
 
   @Override
-  public void messageReceived(ChannelHandlerContext ctx, String blockId) {
-    String path = pResolver.getAbsolutePath(blockId);
+  public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
+    BlockId blockId = BlockId.apply(blockIdString);
+    String path = pResolver.getAbsolutePath(blockId.name());
     // if getFilePath returns null, close the channel
     if (path == null) {
       //ctx.close();
core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
Lines changed: 5 additions & 6 deletions

@@ -22,7 +22,7 @@ import scala.collection.mutable.HashMap
 
 import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
 import org.apache.spark.util.CompletionIterator
 
 
@@ -45,22 +45,21 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
       splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
     }
 
-    val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
+    val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
       case (address, splits) =>
-        (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
+        (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
     }
 
-    def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = {
+    def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
       val blockId = blockPair._1
       val blockOption = blockPair._2
       blockOption match {
         case Some(block) => {
           block.asInstanceOf[Iterator[T]]
         }
         case None => {
-          val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
           blockId match {
-            case regex(shufId, mapId, _) =>
+            case ShuffleBlockId(shufId, mapId, _) =>
              val address = statuses(mapId.toInt)._1
              throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
            case _ =>
core/src/main/scala/org/apache/spark/CacheManager.scala
Lines changed: 4 additions & 4 deletions

@@ -18,7 +18,7 @@
 package org.apache.spark
 
 import scala.collection.mutable.{ArrayBuffer, HashSet}
-import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, RDDBlockId}
 import org.apache.spark.rdd.RDD
 
 
@@ -28,12 +28,12 @@ import org.apache.spark.rdd.RDD
 private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
 
   /** Keys of RDD splits that are being computed/loaded. */
-  private val loading = new HashSet[String]()
+  private val loading = new HashSet[RDDBlockId]()
 
   /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
   def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel)
       : Iterator[T] = {
-    val key = "rdd_%d_%d".format(rdd.id, split.index)
+    val key = RDDBlockId(rdd.id, split.index)
     logDebug("Looking for partition " + key)
     blockManager.get(key) match {
       case Some(values) =>
@@ -73,7 +73,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
         if (context.runningLocally) { return computedValues }
         val elements = new ArrayBuffer[Any]
         elements ++= computedValues
-        blockManager.put(key, elements, storageLevel, true)
+        blockManager.put(key, elements, storageLevel, tellMaster = true)
         return elements.iterator.asInstanceOf[Iterator[T]]
       } finally {
         loading.synchronized {
core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@ import scala.collection.mutable.{ListBuffer, Map, Set}
 import scala.math
 
 import org.apache.spark._
-import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.Utils
 
 private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
@@ -36,7 +36,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
 
   def value = value_
 
-  def blockId: String = BlockManager.toBroadcastId(id)
+  def blockId = BroadcastBlockId(id)
 
   MultiTracker.synchronized {
     SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
Lines changed: 5 additions & 6 deletions

@@ -25,16 +25,15 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
 
 import org.apache.spark.{HttpServer, Logging, SparkEnv}
 import org.apache.spark.io.CompressionCodec
-import org.apache.spark.storage.{BlockManager, StorageLevel}
-import org.apache.spark.util.{MetadataCleanerType, Utils, MetadataCleaner, TimeStampedHashSet}
-
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
 
 private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
   extends Broadcast[T](id) with Logging with Serializable {
 
   def value = value_
 
-  def blockId: String = BlockManager.toBroadcastId(id)
+  def blockId = BroadcastBlockId(id)
 
   HttpBroadcast.synchronized {
     SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
@@ -121,7 +120,7 @@ private object HttpBroadcast extends Logging {
   }
 
   def write(id: Long, value: Any) {
-    val file = new File(broadcastDir, "broadcast-" + id)
+    val file = new File(broadcastDir, BroadcastBlockId(id).name)
     val out: OutputStream = {
       if (compress) {
         compressionCodec.compressedOutputStream(new FileOutputStream(file))
@@ -137,7 +136,7 @@ private object HttpBroadcast extends Logging {
   }
 
   def read[T](id: Long): T = {
-    val url = serverUri + "/broadcast-" + id
+    val url = serverUri + "/" + BroadcastBlockId(id).name
     val in = {
       if (compress) {
         compressionCodec.compressedInputStream(new URL(url).openStream())
core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
Lines changed: 3 additions & 5 deletions

@@ -19,21 +19,19 @@ package org.apache.spark.broadcast
 
 import java.io._
 import java.net._
-import java.util.{Comparator, Random, UUID}
 
-import scala.collection.mutable.{ListBuffer, Map, Set}
-import scala.math
+import scala.collection.mutable.{ListBuffer, Set}
 
 import org.apache.spark._
-import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.Utils
 
 private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
   extends Broadcast[T](id) with Logging with Serializable {
 
   def value = value_
 
-  def blockId = BlockManager.toBroadcastId(id)
+  def blockId = BroadcastBlockId(id)
 
   MultiTracker.synchronized {
     SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
core/src/main/scala/org/apache/spark/executor/Executor.scala
Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@ import scala.collection.mutable.HashMap
 
 import org.apache.spark.scheduler._
 import org.apache.spark._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
 import org.apache.spark.util.Utils
 
 /**
@@ -173,7 +173,7 @@ private[spark] class Executor(
         val serializedResult = {
           if (serializedDirectResult.limit >= akkaFrameSize - 1024) {
             logInfo("Storing result for " + taskId + " in local BlockManager")
-            val blockId = "taskresult_" + taskId
+            val blockId = TaskResultBlockId(taskId)
             env.blockManager.putBytes(
               blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
             ser.serialize(new IndirectTaskResult[Any](blockId))
core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
Lines changed: 10 additions & 12 deletions

@@ -20,17 +20,18 @@ package org.apache.spark.network.netty
 import io.netty.buffer._
 
 import org.apache.spark.Logging
+import org.apache.spark.storage.{TestBlockId, BlockId}
 
 private[spark] class FileHeader (
   val fileLen: Int,
-  val blockId: String) extends Logging {
+  val blockId: BlockId) extends Logging {
 
   lazy val buffer = {
     val buf = Unpooled.buffer()
     buf.capacity(FileHeader.HEADER_SIZE)
     buf.writeInt(fileLen)
-    buf.writeInt(blockId.length)
-    blockId.foreach((x: Char) => buf.writeByte(x))
+    buf.writeInt(blockId.name.length)
+    blockId.name.foreach((x: Char) => buf.writeByte(x))
     //padding the rest of header
     if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) {
       buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
@@ -57,18 +58,15 @@ private[spark] object FileHeader {
     for (i <- 1 to idLength) {
       idBuilder += buf.readByte().asInstanceOf[Char]
     }
-    val blockId = idBuilder.toString()
+    val blockId = BlockId(idBuilder.toString())
     new FileHeader(length, blockId)
   }
 
-
-  def main (args:Array[String]){
-
-    val header = new FileHeader(25,"block_0");
-    val buf = header.buffer;
-    val newheader = FileHeader.create(buf);
-    System.out.println("id="+newheader.blockId+",size="+newheader.fileLen)
-
+  def main (args:Array[String]) {
+    val header = new FileHeader(25, TestBlockId("my_block"))
+    val buf = header.buffer
+    val newHeader = FileHeader.create(buf)
+    System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen)
   }
 }
core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
Lines changed: 14 additions & 13 deletions

@@ -27,12 +27,13 @@ import org.apache.spark.Logging
 import org.apache.spark.network.ConnectionManagerId
 
 import scala.collection.JavaConverters._
+import org.apache.spark.storage.BlockId
 
 
 private[spark] class ShuffleCopier extends Logging {
 
-  def getBlock(host: String, port: Int, blockId: String,
-      resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+  def getBlock(host: String, port: Int, blockId: BlockId,
+      resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
 
     val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
     val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
@@ -41,7 +42,7 @@ private[spark] class ShuffleCopier extends Logging {
     try {
       fc.init()
       fc.connect(host, port)
-      fc.sendRequest(blockId)
+      fc.sendRequest(blockId.name)
      fc.waitForClose()
      fc.close()
    } catch {
@@ -53,14 +54,14 @@ private[spark] class ShuffleCopier extends Logging {
     }
   }
 
-  def getBlock(cmId: ConnectionManagerId, blockId: String,
-      resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+  def getBlock(cmId: ConnectionManagerId, blockId: BlockId,
+      resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
     getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
   }
 
   def getBlocks(cmId: ConnectionManagerId,
-    blocks: Seq[(String, Long)],
-    resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+    blocks: Seq[(BlockId, Long)],
+    resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
 
     for ((blockId, size) <- blocks) {
       getBlock(cmId, blockId, resultCollectCallback)
@@ -71,22 +72,22 @@ private[spark] class ShuffleCopier extends Logging {
 
 private[spark] object ShuffleCopier extends Logging {
 
-  private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
+  private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit)
     extends FileClientHandler with Logging {
 
     override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
       logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
       resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
     }
 
-    override def handleError(blockId: String) {
+    override def handleError(blockId: BlockId) {
      if (!isComplete) {
        resultCollectCallBack(blockId, -1, null)
      }
    }
  }
 
-  def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
+  def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) {
     if (size != -1) {
       logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
     }
@@ -99,20 +100,20 @@ private[spark] object ShuffleCopier extends Logging {
     }
     val host = args(0)
     val port = args(1).toInt
-    val file = args(2)
+    val blockId = BlockId(args(2))
     val threads = if (args.length > 3) args(3).toInt else 10
 
     val copiers = Executors.newFixedThreadPool(80)
     val tasks = (for (i <- Range(0, threads)) yield {
       Executors.callable(new Runnable() {
        def run() {
          val copier = new ShuffleCopier()
-          copier.getBlock(host, port, file, echoResultCollectCallBack)
+          copier.getBlock(host, port, blockId, echoResultCollectCallBack)
        }
      })
    }).asJava
    copiers.invokeAll(tasks)
-    copiers.shutdown
+    copiers.shutdown()
    System.exit(0)
  }
}