
Commit d4fa04e

aarondav authored and rxin committed
[SPARK-4187] [Core] Switch to binary protocol for external shuffle service messages
This PR eliminates the network package's usage of the Java serializer and replaces it with Encodable, which is a lightweight binary protocol. Each message is preceded by a type id, which will allow us to change messages (by only adding new ones), or to change the format entirely by switching to a special id (such as -1). This protocol has the advantage over Java serialization that we can guarantee that messages will remain compatible across compiled versions and JVMs, though it does not provide a clean way to do schema migration. In the future, it may be good to use a more heavyweight serialization format like protobuf, thrift, or avro, but these all add several dependencies which are unnecessary at the present time.

Additionally, this unifies the RPC messages of NettyBlockTransferService and ExternalShuffleClient.

Author: Aaron Davidson <[email protected]>

Closes #3146 from aarondav/free and squashes the following commits:

ed1102a [Aaron Davidson] Remove some unused imports
b8e2a49 [Aaron Davidson] Add appId to test
538f2a3 [Aaron Davidson] [SPARK-4187] [Core] Switch to binary protocol for external shuffle service messages
1 parent 3abdb1b commit d4fa04e

30 files changed, +640 -284 lines changed
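At the heart of the change is the framing the commit message describes: every RPC message goes on the wire as a single type-id byte followed by a fixed-size binary payload produced by the message itself. Below is a minimal sketch of that framing, assuming the Encodable shape visible in the Java diffs further down; the Framing class and its method names are illustrative, not the exact classes this commit adds.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

/** Each message knows its exact encoded size and how to write itself. */
interface Encodable {
  int encodedLength();
  void encode(ByteBuf buf);
}

/** Frames a message as [type id: 1 byte][payload: encodedLength() bytes]. */
final class Framing {
  static byte[] toByteArray(byte typeId, Encodable msg) {
    ByteBuf buf = Unpooled.buffer(1 + msg.encodedLength());
    buf.writeByte(typeId);  // type id first: new message types can be added,
                            // or a special id (such as -1) can switch formats
    msg.encode(buf);        // fixed-layout binary payload, no Java serialization
    return buf.array();
  }

  static byte readTypeId(byte[] messageBytes) {
    // A decoder reads this id back and dispatches to the matching decode().
    return Unpooled.wrappedBuffer(messageBytes).readByte();
  }
}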

core/src/main/scala/org/apache/spark/network/BlockTransferService.scala

Lines changed: 3 additions & 1 deletion
@@ -73,6 +73,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
   def uploadBlock(
       hostname: String,
       port: Int,
+      execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
       level: StorageLevel): Future[Unit]
@@ -110,9 +111,10 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
   def uploadBlockSync(
       hostname: String,
       port: Int,
+      execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
       level: StorageLevel): Unit = {
-    Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf)
+    Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf)
   }
 }

core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala

Lines changed: 12 additions & 19 deletions
@@ -26,18 +26,10 @@ import org.apache.spark.network.BlockDataManager
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
 import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
-import org.apache.spark.network.shuffle.ShuffleStreamHandle
+import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage.{BlockId, StorageLevel}
 
-object NettyMessages {
-  /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */
-  case class OpenBlocks(blockIds: Seq[BlockId])
-
-  /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
-  case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel)
-}
-
 /**
  * Serves requests to open blocks by simply registering one chunk per block requested.
  * Handles opening and uploading arbitrary BlockManager blocks.
@@ -50,28 +42,29 @@ class NettyBlockRpcServer(
     blockManager: BlockDataManager)
   extends RpcHandler with Logging {
 
-  import NettyMessages._
-
   private val streamManager = new OneForOneStreamManager()
 
   override def receive(
       client: TransportClient,
       messageBytes: Array[Byte],
       responseContext: RpcResponseCallback): Unit = {
-    val ser = serializer.newInstance()
-    val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes))
+    val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes)
     logTrace(s"Received request: $message")
 
     message match {
-      case OpenBlocks(blockIds) =>
-        val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData)
+      case openBlocks: OpenBlocks =>
+        val blocks: Seq[ManagedBuffer] =
+          openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
         val streamId = streamManager.registerStream(blocks.iterator)
         logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
-        responseContext.onSuccess(
-          ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array())
+        responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
 
-      case UploadBlock(blockId, blockData, level) =>
-        blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level)
+      case uploadBlock: UploadBlock =>
+        // StorageLevel is serialized as bytes using our JavaSerializer.
+        val level: StorageLevel =
+          serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
+        val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
+        blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level)
         responseContext.onSuccess(new Array[Byte](0))
     }
   }
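The receive path above inverts the framing: BlockTransferMessage.Decoder.fromByteArray reads the type id and hands the remaining bytes to the matching message's decode method. A hedged sketch of that pattern follows; the Ping message and its id are invented for illustration, and the real decoder's message set and ids live in BlockTransferMessage.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

final class DecoderSketch {
  static final byte PING = 0;  // illustrative type id, not from the commit

  static Object fromByteArray(byte[] msg) {
    ByteBuf buf = Unpooled.wrappedBuffer(msg);
    byte type = buf.readByte();            // first byte selects the message type
    switch (type) {
      case PING: return Ping.decode(buf);  // remaining bytes are the payload
      default: throw new IllegalArgumentException("Unknown message type: " + type);
    }
  }

  /** Toy message with a single fixed-size field. */
  static final class Ping {
    final long nonce;
    Ping(long nonce) { this.nonce = nonce; }
    static Ping decode(ByteBuf buf) { return new Ping(buf.readLong()); }
  }
}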

core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala

Lines changed: 10 additions & 5 deletions
@@ -24,10 +24,10 @@ import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
-import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
 import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
 import org.apache.spark.network.server._
 import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.protocol.UploadBlock
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.util.Utils
@@ -46,6 +46,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
   private[this] var transportContext: TransportContext = _
   private[this] var server: TransportServer = _
   private[this] var clientFactory: TransportClientFactory = _
+  private[this] var appId: String = _
 
   override def init(blockDataManager: BlockDataManager): Unit = {
     val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
@@ -60,6 +61,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
     transportContext = new TransportContext(transportConf, rpcHandler)
     clientFactory = transportContext.createClientFactory(bootstrap.toList)
     server = transportContext.createServer()
+    appId = conf.getAppId
     logInfo("Server created on " + server.getPort)
   }
 
@@ -74,8 +76,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
     val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
       override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
         val client = clientFactory.createClient(host, port)
-        new OneForOneBlockFetcher(client, blockIds.toArray, listener)
-          .start(OpenBlocks(blockIds.map(BlockId.apply)))
+        new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
       }
     }
 
@@ -101,12 +102,17 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
   override def uploadBlock(
       hostname: String,
       port: Int,
+      execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
       level: StorageLevel): Future[Unit] = {
     val result = Promise[Unit]()
     val client = clientFactory.createClient(hostname, port)
 
+    // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
+    // using our binary protocol.
+    val levelBytes = serializer.newInstance().serialize(level).array()
+
     // Convert or copy nio buffer into array in order to serialize it.
     val nioBuffer = blockData.nioByteBuffer()
     val array = if (nioBuffer.hasArray) {
@@ -117,8 +123,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
       data
     }
 
-    val ser = serializer.newInstance()
-    client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(),
+    client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray,
       new RpcResponseCallback {
         override def onSuccess(response: Array[Byte]): Unit = {
           logTrace(s"Successfully uploaded block $blockId")
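The call site above builds an UploadBlock from five fields: appId, execId, blockId, the Java-serialized StorageLevel bytes, and the block data. A hedged sketch of how such a message plausibly encodes itself with the Encoders helpers added by this commit; the field order follows the call site, and this is not the exact committed UploadBlock.java.

import io.netty.buffer.ByteBuf;
import org.apache.spark.network.protocol.Encoders;

class UploadBlockSketch {
  private final String appId;
  private final String execId;
  private final String blockId;
  private final byte[] metadata;   // Java-serialized StorageLevel, kept opaque
  private final byte[] blockData;

  UploadBlockSketch(String appId, String execId, String blockId,
                    byte[] metadata, byte[] blockData) {
    this.appId = appId;
    this.execId = execId;
    this.blockId = blockId;
    this.metadata = metadata;
    this.blockData = blockData;
  }

  /** Exact payload size, so the framing code can size its buffer up front. */
  int encodedLength() {
    return Encoders.Strings.encodedLength(appId)
        + Encoders.Strings.encodedLength(execId)
        + Encoders.Strings.encodedLength(blockId)
        + Encoders.ByteArrays.encodedLength(metadata)
        + Encoders.ByteArrays.encodedLength(blockData);
  }

  /** Fields are written, and must be read back, in this order. */
  void encode(ByteBuf buf) {
    Encoders.Strings.encode(buf, appId);
    Encoders.Strings.encode(buf, execId);
    Encoders.Strings.encode(buf, blockId);
    Encoders.ByteArrays.encode(buf, metadata);
    Encoders.ByteArrays.encode(buf, blockData);
  }
}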

core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala

Lines changed: 1 addition & 0 deletions
@@ -137,6 +137,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
   override def uploadBlock(
       hostname: String,
       port: Int,
+      execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
       level: StorageLevel)

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 3 additions & 2 deletions
@@ -35,7 +35,8 @@ import org.apache.spark.io.CompressionCodec
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService}
-import org.apache.spark.network.shuffle.{ExecutorShuffleInfo, ExternalShuffleClient}
+import org.apache.spark.network.shuffle.ExternalShuffleClient
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
 import org.apache.spark.network.util.{ConfigProvider, TransportConf}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.ShuffleManager
@@ -939,7 +940,7 @@ private[spark] class BlockManager(
         data.rewind()
         logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
         blockTransferService.uploadBlockSync(
-          peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel)
+          peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel)
         logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms"
           .format(System.currentTimeMillis - onePeerStartTime))
         peersReplicatedTo += peer

core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala

Lines changed: 3 additions & 1 deletion
@@ -36,7 +36,9 @@ import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMat
 
 class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers {
   test("security default off") {
-    testConnection(new SparkConf, new SparkConf) match {
+    val conf = new SparkConf()
+      .set("spark.app.id", "app-id")
+    testConnection(conf, conf) match {
       case Success(_) => // expected
       case Failure(t) => fail(t)
     }

network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java

Lines changed: 4 additions & 8 deletions
@@ -38,23 +38,19 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) {
 
   @Override
   public int encodedLength() {
-    return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length;
+    return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString);
   }
 
   @Override
   public void encode(ByteBuf buf) {
     streamChunkId.encode(buf);
-    byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
-    buf.writeInt(errorBytes.length);
-    buf.writeBytes(errorBytes);
+    Encoders.Strings.encode(buf, errorString);
   }
 
   public static ChunkFetchFailure decode(ByteBuf buf) {
     StreamChunkId streamChunkId = StreamChunkId.decode(buf);
-    int numErrorStringBytes = buf.readInt();
-    byte[] errorBytes = new byte[numErrorStringBytes];
-    buf.readBytes(errorBytes);
-    return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8));
+    String errorString = Encoders.Strings.decode(buf);
+    return new ChunkFetchFailure(streamChunkId, errorString);
   }
 
   @Override
network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+
+import com.google.common.base.Charsets;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+/** Provides a canonical set of Encoders for simple types. */
+public class Encoders {
+
+  /** Strings are encoded with their length followed by UTF-8 bytes. */
+  public static class Strings {
+    public static int encodedLength(String s) {
+      return 4 + s.getBytes(Charsets.UTF_8).length;
+    }
+
+    public static void encode(ByteBuf buf, String s) {
+      byte[] bytes = s.getBytes(Charsets.UTF_8);
+      buf.writeInt(bytes.length);
+      buf.writeBytes(bytes);
+    }
+
+    public static String decode(ByteBuf buf) {
+      int length = buf.readInt();
+      byte[] bytes = new byte[length];
+      buf.readBytes(bytes);
+      return new String(bytes, Charsets.UTF_8);
+    }
+  }
+
+  /** Byte arrays are encoded with their length followed by bytes. */
+  public static class ByteArrays {
+    public static int encodedLength(byte[] arr) {
+      return 4 + arr.length;
+    }
+
+    public static void encode(ByteBuf buf, byte[] arr) {
+      buf.writeInt(arr.length);
+      buf.writeBytes(arr);
+    }
+
+    public static byte[] decode(ByteBuf buf) {
+      int length = buf.readInt();
+      byte[] bytes = new byte[length];
+      buf.readBytes(bytes);
+      return bytes;
+    }
+  }
+
+  /** String arrays are encoded with the number of strings followed by per-String encoding. */
+  public static class StringArrays {
+    public static int encodedLength(String[] strings) {
+      int totalLength = 4;
+      for (String s : strings) {
+        totalLength += Strings.encodedLength(s);
+      }
+      return totalLength;
+    }
+
+    public static void encode(ByteBuf buf, String[] strings) {
+      buf.writeInt(strings.length);
+      for (String s : strings) {
+        Strings.encode(buf, s);
+      }
+    }
+
+    public static String[] decode(ByteBuf buf) {
+      int numStrings = buf.readInt();
+      String[] strings = new String[numStrings];
+      for (int i = 0; i < strings.length; i ++) {
+        strings[i] = Strings.decode(buf);
+      }
+      return strings;
+    }
+  }
+}
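A small round-trip exercising these helpers (illustrative usage, not part of the commit; note that decode must mirror the encode order):

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.apache.spark.network.protocol.Encoders;

public class EncodersRoundTrip {
  public static void main(String[] args) {
    String appId = "app-example";                             // made-up values
    String[] blockIds = { "shuffle_0_0_0", "shuffle_0_0_1" };

    // Size the buffer exactly, the same way Encodable implementations use
    // encodedLength() before calling encode().
    ByteBuf buf = Unpooled.buffer(
        Encoders.Strings.encodedLength(appId)
        + Encoders.StringArrays.encodedLength(blockIds));

    Encoders.Strings.encode(buf, appId);
    Encoders.StringArrays.encode(buf, blockIds);

    System.out.println(Encoders.Strings.decode(buf));             // app-example
    System.out.println(Encoders.StringArrays.decode(buf).length); // 2
  }
}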

network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java

Lines changed: 4 additions & 8 deletions
@@ -36,23 +36,19 @@ public RpcFailure(long requestId, String errorString) {
 
   @Override
   public int encodedLength() {
-    return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length;
+    return 8 + Encoders.Strings.encodedLength(errorString);
   }
 
   @Override
   public void encode(ByteBuf buf) {
     buf.writeLong(requestId);
-    byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
-    buf.writeInt(errorBytes.length);
-    buf.writeBytes(errorBytes);
+    Encoders.Strings.encode(buf, errorString);
   }
 
   public static RpcFailure decode(ByteBuf buf) {
     long requestId = buf.readLong();
-    int numErrorStringBytes = buf.readInt();
-    byte[] errorBytes = new byte[numErrorStringBytes];
-    buf.readBytes(errorBytes);
-    return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8));
+    String errorString = Encoders.Strings.decode(buf);
+    return new RpcFailure(requestId, errorString);
   }
 
   @Override

network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java

Lines changed: 3 additions & 6 deletions
@@ -44,21 +44,18 @@ public RpcRequest(long requestId, byte[] message) {
 
   @Override
   public int encodedLength() {
-    return 8 + 4 + message.length;
+    return 8 + Encoders.ByteArrays.encodedLength(message);
  }
 
   @Override
   public void encode(ByteBuf buf) {
     buf.writeLong(requestId);
-    buf.writeInt(message.length);
-    buf.writeBytes(message);
+    Encoders.ByteArrays.encode(buf, message);
   }
 
   public static RpcRequest decode(ByteBuf buf) {
     long requestId = buf.readLong();
-    int messageLen = buf.readInt();
-    byte[] message = new byte[messageLen];
-    buf.readBytes(message);
+    byte[] message = Encoders.ByteArrays.decode(buf);
     return new RpcRequest(requestId, message);
   }
