26
26
27
27
import com .google .common .annotations .VisibleForTesting ;
28
28
import io .netty .channel .Channel ;
29
+ import org .apache .commons .lang3 .tuple .ImmutablePair ;
30
+ import org .apache .commons .lang3 .tuple .Pair ;
29
31
import org .slf4j .Logger ;
30
32
import org .slf4j .LoggerFactory ;
31
33
@@ -56,7 +58,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
56
58
57
59
private final Map <Long , RpcResponseCallback > outstandingRpcs ;
58
60
59
- private final Queue <StreamCallback > streamCallbacks ;
61
+ private final Queue <Pair < String , StreamCallback > > streamCallbacks ;
60
62
private volatile boolean streamActive ;
61
63
62
64
/** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
@@ -88,9 +90,9 @@ public void removeRpcRequest(long requestId) {
88
90
outstandingRpcs .remove (requestId );
89
91
}
90
92
91
- public void addStreamCallback (StreamCallback callback ) {
93
+ public void addStreamCallback (String streamId , StreamCallback callback ) {
92
94
timeOfLastRequestNs .set (System .nanoTime ());
93
- streamCallbacks .offer (callback );
95
+ streamCallbacks .offer (ImmutablePair . of ( streamId , callback ) );
94
96
}
95
97
96
98
@ VisibleForTesting
@@ -104,15 +106,31 @@ public void deactivateStream() {
104
106
*/
105
107
private void failOutstandingRequests (Throwable cause ) {
106
108
for (Map .Entry <StreamChunkId , ChunkReceivedCallback > entry : outstandingFetches .entrySet ()) {
107
- entry .getValue ().onFailure (entry .getKey ().chunkIndex , cause );
109
+ try {
110
+ entry .getValue ().onFailure (entry .getKey ().chunkIndex , cause );
111
+ } catch (Exception e ) {
112
+ logger .warn ("ChunkReceivedCallback.onFailure throws exception" , e );
113
+ }
108
114
}
109
115
for (Map .Entry <Long , RpcResponseCallback > entry : outstandingRpcs .entrySet ()) {
110
- entry .getValue ().onFailure (cause );
116
+ try {
117
+ entry .getValue ().onFailure (cause );
118
+ } catch (Exception e ) {
119
+ logger .warn ("RpcResponseCallback.onFailure throws exception" , e );
120
+ }
121
+ }
122
+ for (Pair <String , StreamCallback > entry : streamCallbacks ) {
123
+ try {
124
+ entry .getValue ().onFailure (entry .getKey (), cause );
125
+ } catch (Exception e ) {
126
+ logger .warn ("StreamCallback.onFailure throws exception" , e );
127
+ }
111
128
}
112
129
113
130
// It's OK if new fetches appear, as they will fail immediately.
114
131
outstandingFetches .clear ();
115
132
outstandingRpcs .clear ();
133
+ streamCallbacks .clear ();
116
134
}
117
135
118
136
@ Override
@@ -190,8 +208,9 @@ public void handle(ResponseMessage message) throws Exception {
190
208
}
191
209
} else if (message instanceof StreamResponse ) {
192
210
StreamResponse resp = (StreamResponse ) message ;
193
- StreamCallback callback = streamCallbacks .poll ();
194
- if (callback != null ) {
211
+ Pair <String , StreamCallback > entry = streamCallbacks .poll ();
212
+ if (entry != null ) {
213
+ StreamCallback callback = entry .getValue ();
195
214
if (resp .byteCount > 0 ) {
196
215
StreamInterceptor interceptor = new StreamInterceptor (this , resp .streamId , resp .byteCount ,
197
216
callback );
@@ -216,8 +235,9 @@ public void handle(ResponseMessage message) throws Exception {
216
235
}
217
236
} else if (message instanceof StreamFailure ) {
218
237
StreamFailure resp = (StreamFailure ) message ;
219
- StreamCallback callback = streamCallbacks .poll ();
220
- if (callback != null ) {
238
+ Pair <String , StreamCallback > entry = streamCallbacks .poll ();
239
+ if (entry != null ) {
240
+ StreamCallback callback = entry .getValue ();
221
241
try {
222
242
callback .onFailure (resp .streamId , new RuntimeException (resp .error ));
223
243
} catch (IOException ioe ) {
0 commit comments