Skip to content

Commit 37a4b6c

Browse files
committed
Remove streamIds from TransportRequestHandler.
1 parent 3b3f38a commit 37a4b6c

File tree

3 files changed

+37
-16
lines changed

3 files changed

+37
-16
lines changed

network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
import java.util.Iterator;
2121
import java.util.Map;
2222
import java.util.Random;
23+
import java.util.Set;
2324
import java.util.concurrent.ConcurrentHashMap;
2425
import java.util.concurrent.atomic.AtomicLong;
2526

27+
import com.google.common.collect.Sets;
28+
import io.netty.channel.Channel;
2629
import org.slf4j.Logger;
2730
import org.slf4j.LoggerFactory;
2831

@@ -38,6 +41,9 @@ public class OneForOneStreamManager extends StreamManager {
3841
private final AtomicLong nextStreamId;
3942
private final Map<Long, StreamState> streams;
4043

44+
/** List of all stream ids that are associated to specified channel. **/
45+
private final Map<Channel, Set<Long>> streamIds;
46+
4147
/** State of a single stream. */
4248
private static class StreamState {
4349
final Iterator<ManagedBuffer> buffers;
@@ -56,11 +62,15 @@ public OneForOneStreamManager() {
5662
// This does not need to be globally unique, only unique to this class.
5763
nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
5864
streams = new ConcurrentHashMap<Long, StreamState>();
65+
streamIds = new ConcurrentHashMap<Channel, Set<Long>>();
5966
}
6067

6168
@Override
62-
public boolean streamHasNext(long streamId) {
63-
return streams.containsKey(streamId);
69+
public void registerChannel(Channel channel, long streamId) {
70+
if (!streamIds.containsKey(channel)) {
71+
streamIds.put(channel, Sets.newHashSet());
72+
}
73+
streamIds.get(channel).add(streamId);
6474
}
6575

6676
@Override
@@ -84,6 +94,17 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
8494
return nextChunk;
8595
}
8696

97+
@Override
98+
public void connectionTerminated(Channel channel) {
99+
// Release all associated streams
100+
if (streamIds.containsKey(channel)) {
101+
for (long streamId : streamIds.get(channel)) {
102+
connectionTerminated(streamId);
103+
}
104+
streamIds.remove(channel);
105+
}
106+
}
107+
87108
@Override
88109
public void connectionTerminated(long streamId) {
89110
// Release all remaining buffers.

network/common/src/main/java/org/apache/spark/network/server/StreamManager.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.network.server;
1919

20+
import io.netty.channel.Channel;
21+
2022
import org.apache.spark.network.buffer.ManagedBuffer;
2123

2224
/**
@@ -44,9 +46,15 @@ public abstract class StreamManager {
4446
public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
4547

4648
/**
47-
* Indicates that if the specified stream has next chunks to read further.
49+
* Register the given stream to the associated channel. So these streams can be cleaned up later.
50+
*/
51+
public void registerChannel(Channel channel, long streamId) { }
52+
53+
/**
54+
* Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
55+
* to read from the associated streams again, so any state can be cleaned up.
4856
*/
49-
public boolean streamHasNext(long streamId) { return true; }
57+
public void connectionTerminated(Channel channel) { }
5058

5159
/**
5260
* Indicates that the TCP connection that was tied to the given stream has been terminated. After

network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
6262
/** Returns each chunk part of a stream. */
6363
private final StreamManager streamManager;
6464

65-
/** List of all stream ids that have been read on this handler, used for cleanup. */
66-
private final Set<Long> streamIds;
67-
6865
public TransportRequestHandler(
6966
Channel channel,
7067
TransportClient reverseClient,
@@ -73,7 +70,6 @@ public TransportRequestHandler(
7370
this.reverseClient = reverseClient;
7471
this.rpcHandler = rpcHandler;
7572
this.streamManager = rpcHandler.getStreamManager();
76-
this.streamIds = Sets.newHashSet();
7773
}
7874

7975
@Override
@@ -82,10 +78,8 @@ public void exceptionCaught(Throwable cause) {
8278

8379
@Override
8480
public void channelUnregistered() {
85-
// Inform the StreamManager that these streams will no longer be read from.
86-
for (long streamId : streamIds) {
87-
streamManager.connectionTerminated(streamId);
88-
}
81+
// Inform the StreamManager that this channel is unregistered.
82+
streamManager.connectionTerminated(channel);
8983
rpcHandler.connectionTerminated(reverseClient);
9084
}
9185

@@ -102,16 +96,14 @@ public void handle(RequestMessage request) {
10296

10397
private void processFetchRequest(final ChunkFetchRequest req) {
10498
final String client = NettyUtils.getRemoteAddress(channel);
105-
streamIds.add(req.streamChunkId.streamId);
99+
100+
streamManager.registerChannel(channel, req.streamChunkId.streamId);
106101

107102
logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);
108103

109104
ManagedBuffer buf;
110105
try {
111106
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
112-
if (!streamManager.streamHasNext(req.streamChunkId.streamId)) {
113-
streamIds.remove(req.streamChunkId.streamId);
114-
}
115107
} catch (Exception e) {
116108
logger.error(String.format(
117109
"Error opening block %s for request from %s", req.streamChunkId, client), e);

0 commit comments

Comments
 (0)