20
20
import java .util .Iterator ;
21
21
import java .util .Map ;
22
22
import java .util .Random ;
23
+ import java .util .Set ;
23
24
import java .util .concurrent .ConcurrentHashMap ;
24
25
import java .util .concurrent .atomic .AtomicLong ;
25
26
27
+ import io .netty .channel .Channel ;
26
28
import org .slf4j .Logger ;
27
29
import org .slf4j .LoggerFactory ;
28
30
29
31
import org .apache .spark .network .buffer .ManagedBuffer ;
30
32
33
+ import com .google .common .base .Preconditions ;
34
+
31
35
/**
32
36
* StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually
33
37
* fetched as chunks by the client. Each registered buffer is one chunk.
@@ -36,18 +40,21 @@ public class OneForOneStreamManager extends StreamManager {
36
40
private final Logger logger = LoggerFactory .getLogger (OneForOneStreamManager .class );
37
41
38
42
private final AtomicLong nextStreamId ;
39
- private final Map <Long , StreamState > streams ;
43
+ private final ConcurrentHashMap <Long , StreamState > streams ;
40
44
41
45
/** State of a single stream. */
42
46
private static class StreamState {
43
47
final Iterator <ManagedBuffer > buffers ;
44
48
49
+ // The channel associated to the stream
50
+ Channel associatedChannel = null ;
51
+
45
52
// Used to keep track of the index of the buffer that the user has retrieved, just to ensure
46
53
// that the caller only requests each chunk one at a time, in order.
47
54
int curChunk = 0 ;
48
55
49
56
StreamState (Iterator <ManagedBuffer > buffers ) {
50
- this .buffers = buffers ;
57
+ this .buffers = Preconditions . checkNotNull ( buffers ) ;
51
58
}
52
59
}
53
60
@@ -58,6 +65,13 @@ public OneForOneStreamManager() {
58
65
streams = new ConcurrentHashMap <Long , StreamState >();
59
66
}
60
67
68
+ @ Override
69
+ public void registerChannel (Channel channel , long streamId ) {
70
+ if (streams .containsKey (streamId )) {
71
+ streams .get (streamId ).associatedChannel = channel ;
72
+ }
73
+ }
74
+
61
75
@ Override
62
76
public ManagedBuffer getChunk (long streamId , int chunkIndex ) {
63
77
StreamState state = streams .get (streamId );
@@ -80,12 +94,17 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
80
94
}
81
95
82
96
@ Override
83
- public void connectionTerminated (long streamId ) {
84
- // Release all remaining buffers.
85
- StreamState state = streams .remove (streamId );
86
- if (state != null && state .buffers != null ) {
87
- while (state .buffers .hasNext ()) {
88
- state .buffers .next ().release ();
97
+ public void connectionTerminated (Channel channel ) {
98
+ // Close all streams which have been associated with the channel.
99
+ for (Map .Entry <Long , StreamState > entry : streams .entrySet ()) {
100
+ StreamState state = entry .getValue ();
101
+ if (state .associatedChannel == channel ) {
102
+ streams .remove (entry .getKey ());
103
+
104
+ // Release all remaining buffers.
105
+ while (state .buffers .hasNext ()) {
106
+ state .buffers .next ().release ();
107
+ }
89
108
}
90
109
}
91
110
}
0 commit comments