diff --git a/driver-core/src/main/com/mongodb/internal/connection/AsynchronousChannelStream.java b/driver-core/src/main/com/mongodb/internal/connection/AsynchronousChannelStream.java index bbb18497ee..89396dae5d 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/AsynchronousChannelStream.java +++ b/driver-core/src/main/com/mongodb/internal/connection/AsynchronousChannelStream.java @@ -37,6 +37,7 @@ import java.util.concurrent.atomic.AtomicReference; import static com.mongodb.assertions.Assertions.assertTrue; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; import static com.mongodb.internal.thread.InterruptionUtil.interruptAndCreateMongoInterruptedException; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -88,7 +89,7 @@ protected void setChannel(final ExtendedAsynchronousByteChannel channel) { @Override public void writeAsync(final List buffers, final OperationContext operationContext, - final AsyncCompletionHandler handler) { + final AsyncCompletionHandler handler) { AsyncWritableByteChannelAdapter byteChannel = new AsyncWritableByteChannelAdapter(); Iterator iter = buffers.iterator(); pipeOneBuffer(byteChannel, iter.next(), operationContext, new AsyncCompletionHandler() { @@ -189,8 +190,11 @@ public void failed(final Throwable t) { private class AsyncWritableByteChannelAdapter { void write(final ByteBuffer src, final OperationContext operationContext, final AsyncCompletionHandler handler) { - getChannel().write(src, operationContext.getTimeoutContext().getWriteTimeoutMS(), MILLISECONDS, null, - new AsyncWritableByteChannelAdapter.WriteCompletionHandler(handler)); + beginAsync().thenRun((c) -> { + long writeTimeoutMS = operationContext.getTimeoutContext().getWriteTimeoutMS(); + getChannel().write(src, writeTimeoutMS, MILLISECONDS, null, + new AsyncWritableByteChannelAdapter.WriteCompletionHandler(c.asHandler())); + }).finish(handler.asCallback()); } private class WriteCompletionHandler extends BaseCompletionHandler { @@ -222,7 +226,7 @@ private final class BasicCompletionHandler extends BaseCompletionHandler handler) { + final AsyncCompletionHandler handler) { super(handler); this.byteBufReference = new AtomicReference<>(dst); this.operationContext = operationContext; @@ -231,17 +235,20 @@ private BasicCompletionHandler(final ByteBuf dst, final OperationContext operati @Override public void completed(final Integer result, final Void attachment) { AsyncCompletionHandler localHandler = getHandlerAndClear(); - ByteBuf localByteBuf = byteBufReference.getAndSet(null); - if (result == -1) { - localByteBuf.release(); - localHandler.failed(new MongoSocketReadException("Prematurely reached end of stream", serverAddress)); - } else if (!localByteBuf.hasRemaining()) { - localByteBuf.flip(); - localHandler.completed(localByteBuf); - } else { - getChannel().read(localByteBuf.asNIO(), operationContext.getTimeoutContext().getReadTimeoutMS(), MILLISECONDS, null, - new BasicCompletionHandler(localByteBuf, operationContext, localHandler)); - } + beginAsync().thenSupply((c) -> { + ByteBuf localByteBuf = byteBufReference.getAndSet(null); + if (result == -1) { + localByteBuf.release(); + throw new MongoSocketReadException("Prematurely reached end of stream", serverAddress); + } else if (!localByteBuf.hasRemaining()) { + localByteBuf.flip(); + c.complete(localByteBuf); + } else { + long readTimeoutMS = operationContext.getTimeoutContext().getReadTimeoutMS(); + getChannel().read(localByteBuf.asNIO(), readTimeoutMS, MILLISECONDS, null, + new BasicCompletionHandler(localByteBuf, operationContext, c.asHandler())); + } + }).finish(localHandler.asCallback()); } @Override