
Commit f084c5d

Marcelo Vanzin authored and rxin committed
[SPARK-6578] [core] Fix thread-safety issue in outbound path of network library.
While the inbound path of a netty pipeline is thread-safe, the outbound path is not. That means that multiple threads can compete to write messages to the next stage of the pipeline.

The network library sometimes breaks a single RPC message into multiple buffers internally to avoid copying data (see MessageEncoder). This can result in the following scenario (where "FxBy" means "frame x, buffer y"):

  T1        F1B1          F1B2
               \              \
                \              \
  socket       F1B1  F2B1     F1B2  F2B2
                /              /
               /              /
  T2        F2B1          F2B2

And the frames now cannot be rebuilt on the receiving side because the different messages have been mixed up on the wire.

The fix wraps these multi-buffer messages into a `FileRegion` object so that these messages are written "atomically" to the next pipeline handler.

Author: Marcelo Vanzin <[email protected]>

Closes #5234 from vanzin/SPARK-6578 and squashes the following commits:

16b2d70 [Marcelo Vanzin] Forgot to update a type.
c9c2e4e [Marcelo Vanzin] Review comments: simplify some code.
9c888ac [Marcelo Vanzin] Small style nits.
8474bab [Marcelo Vanzin] Fix multiple calls to MessageWithHeader.transferTo().
e26509f [Marcelo Vanzin] Merge branch 'master' into SPARK-6578
c503f6c [Marcelo Vanzin] Implement a custom FileRegion instead of using locks.
84aa7ce [Marcelo Vanzin] Rename handler to the correct name.
432f3bd [Marcelo Vanzin] Remove unneeded method.
8d70e60 [Marcelo Vanzin] Fix thread-safety issue in outbound path of network library.
1 parent fb25e8c commit f084c5d
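
To see the shape of the fix before reading the diffs, here is a minimal, JDK-only sketch of the idea; the names (AtomicFrame, transferTo's single-argument form) are illustrative only and not part of the patch. Instead of handing the outbound path two separate buffers that another thread's buffers could land between, the sender hands it a single object that owns both pieces, so the header and body travel together.

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.nio.charset.StandardCharsets;

// Illustrative sketch only: one object owns both the header and the body and
// drains them through a single transfer loop, so no other message's buffers
// can be queued between them.
final class AtomicFrame {
  private final ByteBuffer header;
  private final ByteBuffer body;

  AtomicFrame(ByteBuffer header, ByteBuffer body) {
    this.header = header;
    this.body = body;
  }

  // Writes the header, then the body; returns the total number of bytes written.
  long transferTo(WritableByteChannel target) throws IOException {
    long written = 0;
    while (header.hasRemaining()) {
      written += target.write(header);
    }
    while (body.hasRemaining()) {
      written += target.write(body);
    }
    return written;
  }

  public static void main(String[] args) throws IOException {
    ByteArrayOutputStream sink = new ByteArrayOutputStream();
    WritableByteChannel channel = Channels.newChannel(sink);
    AtomicFrame frame = new AtomicFrame(
      ByteBuffer.wrap("F1B1".getBytes(StandardCharsets.UTF_8)),
      ByteBuffer.wrap("F1B2".getBytes(StandardCharsets.UTF_8)));
    long written = frame.transferTo(channel);
    System.out.println(written + " bytes: " + sink.toString("UTF-8"));  // 8 bytes: F1B1F1B2
  }
}

The actual patch does the same thing with Netty types: MessageEncoder now emits a single MessageWithHeader (a FileRegion) instead of a header buffer followed by a body buffer.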

File tree

7 files changed: +364 -10 lines changed


network/common/pom.xml

Lines changed: 5 additions & 0 deletions
@@ -80,6 +80,11 @@
       <artifactId>mockito-all</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.slf4j</groupId>
+      <artifactId>slf4j-log4j12</artifactId>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
 
   <build>

network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java

Lines changed: 4 additions & 2 deletions
@@ -72,9 +72,11 @@ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {
     in.encode(header);
     assert header.writableBytes() == 0;
 
-    out.add(header);
     if (body != null && bodyLength > 0) {
-      out.add(body);
+      out.add(new MessageWithHeader(header, body, bodyLength));
+    } else {
+      out.add(header);
     }
   }
+
 }
network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java (new file)

Lines changed: 106 additions & 0 deletions

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.protocol;

import java.io.IOException;
import java.nio.channels.WritableByteChannel;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import io.netty.buffer.ByteBuf;
import io.netty.channel.FileRegion;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCountUtil;

/**
 * A wrapper message that holds two separate pieces (a header and a body) to avoid
 * copying the body's content.
 */
class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {

  private final ByteBuf header;
  private final int headerLength;
  private final Object body;
  private final long bodyLength;
  private long totalBytesTransferred;

  MessageWithHeader(ByteBuf header, Object body, long bodyLength) {
    Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion,
      "Body must be a ByteBuf or a FileRegion.");
    this.header = header;
    this.headerLength = header.readableBytes();
    this.body = body;
    this.bodyLength = bodyLength;
  }

  @Override
  public long count() {
    return headerLength + bodyLength;
  }

  @Override
  public long position() {
    return 0;
  }

  @Override
  public long transfered() {
    return totalBytesTransferred;
  }

  @Override
  public long transferTo(WritableByteChannel target, long position) throws IOException {
    Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position.");
    long written = 0;

    if (position < headerLength) {
      written += copyByteBuf(header, target);
      if (header.readableBytes() > 0) {
        totalBytesTransferred += written;
        return written;
      }
    }

    if (body instanceof FileRegion) {
      // Adjust the position. If the write is happening as part of the same call where the header
      // (or some part of it) is written, `position` will be less than the header size, so we want
      // to start from position 0 in the FileRegion object. Otherwise, we start from the position
      // requested by the caller.
      long bodyPos = position > headerLength ? position - headerLength : 0;
      written += ((FileRegion) body).transferTo(target, bodyPos);
    } else if (body instanceof ByteBuf) {
      written += copyByteBuf((ByteBuf) body, target);
    }

    totalBytesTransferred += written;
    return written;
  }

  @Override
  protected void deallocate() {
    header.release();
    ReferenceCountUtil.release(body);
  }

  private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException {
    int written = target.write(buf.nioBuffer());
    buf.skipBytes(written);
    return written;
  }

}
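
For reference, a caller drives the region by calling transferTo() repeatedly with transfered() as the position until every byte is reported as transferred, which is the same loop the test suites below use. The demo class here is hypothetical and not part of the patch; it assumes it sits in the org.apache.spark.network.protocol package (MessageWithHeader is package-private) with the test sources on the classpath, since ByteArrayWritableChannel is added under the tests in this commit.

package org.apache.spark.network.protocol;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.ByteArrayWritableChannel;

// Hypothetical demo (not part of the patch): drain a MessageWithHeader into a
// byte-array channel by looping on transferTo() until count() bytes are out.
public class MessageWithHeaderDemo {
  public static void main(String[] args) throws Exception {
    ByteBuf header = Unpooled.copyLong(42);
    ByteBuf body = Unpooled.copyLong(84);
    MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes());

    ByteArrayWritableChannel sink = new ByteArrayWritableChannel((int) msg.count());
    while (msg.transfered() < msg.count()) {
      // The position argument must equal the number of bytes already transferred.
      msg.transferTo(sink, msg.transfered());
    }
    System.out.println("wrote " + msg.count() + " bytes: header then body, never interleaved");
    msg.release();
  }
}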
network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java (new file)

Lines changed: 55 additions & 0 deletions

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network;

import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;

public class ByteArrayWritableChannel implements WritableByteChannel {

  private final byte[] data;
  private int offset;

  public ByteArrayWritableChannel(int size) {
    this.data = new byte[size];
    this.offset = 0;
  }

  public byte[] getData() {
    return data;
  }

  @Override
  public int write(ByteBuffer src) {
    int available = src.remaining();
    src.get(data, offset, available);
    offset += available;
    return available;
  }

  @Override
  public void close() {

  }

  @Override
  public boolean isOpen() {
    return true;
  }

}
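
A quick usage sketch of the channel above; the demo class is hypothetical and not part of the patch. Every write() lands sequentially in the backing array, which is what lets the tests inspect the exact bytes a FileRegion produced.

package org.apache.spark.network;

import java.nio.ByteBuffer;
import java.util.Arrays;

// Hypothetical demo (not part of the patch): successive writes accumulate
// into the fixed-size backing array returned by getData().
public class ByteArrayWritableChannelDemo {
  public static void main(String[] args) {
    ByteArrayWritableChannel channel = new ByteArrayWritableChannel(8);
    channel.write(ByteBuffer.wrap(new byte[] {1, 2, 3, 4}));
    channel.write(ByteBuffer.wrap(new byte[] {5, 6, 7, 8}));
    byte[] data = channel.getData();  // {1, 2, 3, 4, 5, 6, 7, 8}
    System.out.println(Arrays.toString(data));
  }
}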

network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java

Lines changed: 38 additions & 8 deletions
@@ -17,26 +17,34 @@
 
 package org.apache.spark.network;
 
+import java.util.List;
+
+import com.google.common.primitives.Ints;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.FileRegion;
 import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.MessageToMessageEncoder;
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
 
-import org.apache.spark.network.protocol.Message;
-import org.apache.spark.network.protocol.StreamChunkId;
-import org.apache.spark.network.protocol.ChunkFetchRequest;
 import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
 import org.apache.spark.network.protocol.ChunkFetchSuccess;
-import org.apache.spark.network.protocol.RpcRequest;
-import org.apache.spark.network.protocol.RpcFailure;
-import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.Message;
 import org.apache.spark.network.protocol.MessageDecoder;
 import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
 import org.apache.spark.network.util.NettyUtils;
 
 public class ProtocolSuite {
   private void testServerToClient(Message msg) {
-    EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder());
+    EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
+      new MessageEncoder());
     serverChannel.writeOutbound(msg);
 
     EmbeddedChannel clientChannel = new EmbeddedChannel(
@@ -51,7 +59,8 @@ private void testServerToClient(Message msg) {
   }
 
   private void testClientToServer(Message msg) {
-    EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder());
+    EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
+      new MessageEncoder());
     clientChannel.writeOutbound(msg);
 
     EmbeddedChannel serverChannel = new EmbeddedChannel(
@@ -83,4 +92,25 @@ public void responses() {
     testServerToClient(new RpcFailure(0, "this is an error"));
     testServerToClient(new RpcFailure(0, ""));
   }
+
+  /**
+   * Handler to transform a FileRegion into a byte buffer. EmbeddedChannel doesn't actually transfer
+   * bytes, but messages, so this is needed so that the frame decoder on the receiving side can
+   * understand what MessageWithHeader actually contains.
+   */
+  private static class FileRegionEncoder extends MessageToMessageEncoder<FileRegion> {
+
+    @Override
+    public void encode(ChannelHandlerContext ctx, FileRegion in, List<Object> out)
+      throws Exception {
+
+      ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count()));
+      while (in.transfered() < in.count()) {
+        in.transferTo(channel, in.transfered());
+      }
+      out.add(Unpooled.wrappedBuffer(channel.getData()));
+    }
+
+  }
+
 }
network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java (new file)

Lines changed: 129 additions & 0 deletions

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.protocol;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.FileRegion;
import io.netty.util.AbstractReferenceCounted;
import org.junit.Test;

import static org.junit.Assert.*;

import org.apache.spark.network.ByteArrayWritableChannel;

public class MessageWithHeaderSuite {

  @Test
  public void testSingleWrite() throws Exception {
    testFileRegionBody(8, 8);
  }

  @Test
  public void testShortWrite() throws Exception {
    testFileRegionBody(8, 1);
  }

  @Test
  public void testByteBufBody() throws Exception {
    ByteBuf header = Unpooled.copyLong(42);
    ByteBuf body = Unpooled.copyLong(84);
    MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes());

    ByteBuf result = doWrite(msg, 1);
    assertEquals(msg.count(), result.readableBytes());
    assertEquals(42, result.readLong());
    assertEquals(84, result.readLong());
  }

  private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception {
    ByteBuf header = Unpooled.copyLong(42);
    int headerLength = header.readableBytes();
    TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall);
    MessageWithHeader msg = new MessageWithHeader(header, region, region.count());

    ByteBuf result = doWrite(msg, totalWrites / writesPerCall);
    assertEquals(headerLength + region.count(), result.readableBytes());
    assertEquals(42, result.readLong());
    for (long i = 0; i < 8; i++) {
      assertEquals(i, result.readLong());
    }
  }

  private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception {
    int writes = 0;
    ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count());
    while (msg.transfered() < msg.count()) {
      msg.transferTo(channel, msg.transfered());
      writes++;
    }
    assertTrue("Not enough writes!", minExpectedWrites <= writes);
    return Unpooled.wrappedBuffer(channel.getData());
  }

  private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion {

    private final int writeCount;
    private final int writesPerCall;
    private int written;

    TestFileRegion(int totalWrites, int writesPerCall) {
      this.writeCount = totalWrites;
      this.writesPerCall = writesPerCall;
    }

    @Override
    public long count() {
      return 8 * writeCount;
    }

    @Override
    public long position() {
      return 0;
    }

    @Override
    public long transfered() {
      return 8 * written;
    }

    @Override
    public long transferTo(WritableByteChannel target, long position) throws IOException {
      for (int i = 0; i < writesPerCall; i++) {
        ByteBuf buf = Unpooled.copyLong((position / 8) + i);
        ByteBuffer nio = buf.nioBuffer();
        while (nio.remaining() > 0) {
          target.write(nio);
        }
        buf.release();
        written++;
      }
      return 8 * writesPerCall;
    }

    @Override
    protected void deallocate() {
    }

  }

}
