Skip to content

Commit 1686032

Browse files
viiryaaarondav
authored andcommitted
[SPARK-7183] [NETWORK] Fix memory leak of TransportRequestHandler.streamIds
JIRA: https://issues.apache.org/jira/browse/SPARK-7183 Author: Liang-Chi Hsieh <[email protected]> Closes #5743 from viirya/fix_requesthandler_memory_leak and squashes the following commits: cf2c086 [Liang-Chi Hsieh] For comments. 97e205c [Liang-Chi Hsieh] Remove unused import. d35f19a [Liang-Chi Hsieh] For comments. f9a0c37 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into fix_requesthandler_memory_leak 45908b7 [Liang-Chi Hsieh] for style. 17f020f [Liang-Chi Hsieh] Remove unused import. 37a4b6c [Liang-Chi Hsieh] Remove streamIds from TransportRequestHandler. 3b3f38a [Liang-Chi Hsieh] Fix memory leak of TransportRequestHandler.streamIds.
1 parent 1262e31 commit 1686032

File tree

3 files changed

+44
-24
lines changed

3 files changed

+44
-24
lines changed

network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,18 @@
2020
import java.util.Iterator;
2121
import java.util.Map;
2222
import java.util.Random;
23+
import java.util.Set;
2324
import java.util.concurrent.ConcurrentHashMap;
2425
import java.util.concurrent.atomic.AtomicLong;
2526

27+
import io.netty.channel.Channel;
2628
import org.slf4j.Logger;
2729
import org.slf4j.LoggerFactory;
2830

2931
import org.apache.spark.network.buffer.ManagedBuffer;
3032

33+
import com.google.common.base.Preconditions;
34+
3135
/**
3236
* StreamManager which allows registration of an Iterator&lt;ManagedBuffer&gt;, which are individually
3337
* fetched as chunks by the client. Each registered buffer is one chunk.
@@ -36,18 +40,21 @@ public class OneForOneStreamManager extends StreamManager {
3640
private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class);
3741

3842
private final AtomicLong nextStreamId;
39-
private final Map<Long, StreamState> streams;
43+
private final ConcurrentHashMap<Long, StreamState> streams;
4044

4145
/** State of a single stream. */
4246
private static class StreamState {
4347
final Iterator<ManagedBuffer> buffers;
4448

49+
// The channel associated to the stream
50+
Channel associatedChannel = null;
51+
4552
// Used to keep track of the index of the buffer that the user has retrieved, just to ensure
4653
// that the caller only requests each chunk one at a time, in order.
4754
int curChunk = 0;
4855

4956
StreamState(Iterator<ManagedBuffer> buffers) {
50-
this.buffers = buffers;
57+
this.buffers = Preconditions.checkNotNull(buffers);
5158
}
5259
}
5360

@@ -58,6 +65,13 @@ public OneForOneStreamManager() {
5865
streams = new ConcurrentHashMap<Long, StreamState>();
5966
}
6067

68+
@Override
69+
public void registerChannel(Channel channel, long streamId) {
70+
if (streams.containsKey(streamId)) {
71+
streams.get(streamId).associatedChannel = channel;
72+
}
73+
}
74+
6175
@Override
6276
public ManagedBuffer getChunk(long streamId, int chunkIndex) {
6377
StreamState state = streams.get(streamId);
@@ -80,12 +94,17 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
8094
}
8195

8296
@Override
83-
public void connectionTerminated(long streamId) {
84-
// Release all remaining buffers.
85-
StreamState state = streams.remove(streamId);
86-
if (state != null && state.buffers != null) {
87-
while (state.buffers.hasNext()) {
88-
state.buffers.next().release();
97+
public void connectionTerminated(Channel channel) {
98+
// Close all streams which have been associated with the channel.
99+
for (Map.Entry<Long, StreamState> entry: streams.entrySet()) {
100+
StreamState state = entry.getValue();
101+
if (state.associatedChannel == channel) {
102+
streams.remove(entry.getKey());
103+
104+
// Release all remaining buffers.
105+
while (state.buffers.hasNext()) {
106+
state.buffers.next().release();
107+
}
89108
}
90109
}
91110
}

network/common/src/main/java/org/apache/spark/network/server/StreamManager.java

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.network.server;
1919

20+
import io.netty.channel.Channel;
21+
2022
import org.apache.spark.network.buffer.ManagedBuffer;
2123

2224
/**
@@ -44,9 +46,18 @@ public abstract class StreamManager {
4446
public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
4547

4648
/**
47-
* Indicates that the TCP connection that was tied to the given stream has been terminated. After
48-
* this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned
49-
* up.
49+
* Associates a stream with a single client connection, which is guaranteed to be the only reader
50+
* of the stream. The getChunk() method will be called serially on this connection and once the
51+
* connection is closed, the stream will never be used again, enabling cleanup.
52+
*
53+
* This must be called before the first getChunk() on the stream, but it may be invoked multiple
54+
* times with the same channel and stream id.
55+
*/
56+
public void registerChannel(Channel channel, long streamId) { }
57+
58+
/**
59+
* Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
60+
* to read from the associated streams again, so any state can be cleaned up.
5061
*/
51-
public void connectionTerminated(long streamId) { }
62+
public void connectionTerminated(Channel channel) { }
5263
}

network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@
1717

1818
package org.apache.spark.network.server;
1919

20-
import java.util.Set;
21-
2220
import com.google.common.base.Throwables;
23-
import com.google.common.collect.Sets;
2421
import io.netty.channel.Channel;
2522
import io.netty.channel.ChannelFuture;
2623
import io.netty.channel.ChannelFutureListener;
@@ -62,9 +59,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
6259
/** Returns each chunk part of a stream. */
6360
private final StreamManager streamManager;
6461

65-
/** List of all stream ids that have been read on this handler, used for cleanup. */
66-
private final Set<Long> streamIds;
67-
6862
public TransportRequestHandler(
6963
Channel channel,
7064
TransportClient reverseClient,
@@ -73,7 +67,6 @@ public TransportRequestHandler(
7367
this.reverseClient = reverseClient;
7468
this.rpcHandler = rpcHandler;
7569
this.streamManager = rpcHandler.getStreamManager();
76-
this.streamIds = Sets.newHashSet();
7770
}
7871

7972
@Override
@@ -82,10 +75,7 @@ public void exceptionCaught(Throwable cause) {
8275

8376
@Override
8477
public void channelUnregistered() {
85-
// Inform the StreamManager that these streams will no longer be read from.
86-
for (long streamId : streamIds) {
87-
streamManager.connectionTerminated(streamId);
88-
}
78+
streamManager.connectionTerminated(channel);
8979
rpcHandler.connectionTerminated(reverseClient);
9080
}
9181

@@ -102,12 +92,12 @@ public void handle(RequestMessage request) {
10292

10393
private void processFetchRequest(final ChunkFetchRequest req) {
10494
final String client = NettyUtils.getRemoteAddress(channel);
105-
streamIds.add(req.streamChunkId.streamId);
10695

10796
logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);
10897

10998
ManagedBuffer buf;
11099
try {
100+
streamManager.registerChannel(channel, req.streamChunkId.streamId);
111101
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
112102
} catch (Exception e) {
113103
logger.error(String.format(

0 commit comments

Comments
 (0)