Skip to content

Commit 799594d

Browse files
authored
Merge pull request #5898 from cloudflare/dominik/reland-socket-closure
Reland "Forcibly shutdown AsyncIoStream on Socket::close."
2 parents a005053 + 570fe91 commit 799594d

File tree

4 files changed

+46
-12
lines changed

4 files changed

+46
-12
lines changed

src/workerd/api/sockets.c++

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ jsg::Ref<Socket> setupSocket(jsg::Lock& js,
156156

157157
auto refcountedConnection = kj::refcountedWrapper(kj::mv(connection));
158158
// Initialize the readable/writable streams with the readable/writable sides of an AsyncIoStream.
159-
auto sysStreams = newSystemMultiStream(refcountedConnection->addWrappedRef(), ioContext);
159+
auto sysStreams = newSystemMultiStream(*refcountedConnection, ioContext);
160160
auto readable = js.alloc<ReadableStream>(ioContext, kj::mv(sysStreams.readable));
161161
auto allowHalfOpen = getAllowHalfOpen(options);
162162
kj::Maybe<jsg::Promise<void>> eofPromise;
@@ -301,12 +301,23 @@ jsg::Promise<void> Socket::close(jsg::Lock& js) {
301301
// Forcibly abort the readable/writable streams.
302302
auto cancelPromise = readable->getController().cancel(js, kj::none);
303303
auto abortPromise = writable->getController().abort(js, kj::none);
304+
304305
// The below is effectively `Promise.all(cancelPromise, abortPromise)`
305306
return cancelPromise.then(js, [abortPromise = kj::mv(abortPromise)](jsg::Lock& js) mutable {
306307
return kj::mv(abortPromise);
307308
});
308309
})
309310
.then(js, [this](jsg::Lock& js) {
311+
// This task needs to destroyed prior to destroying the AsyncIoStream as it is awaiting
312+
// that stream's `whenWriteDisconnected` promise.
313+
watchForDisconnectTask = nullptr;
314+
315+
// Destroy the tlsStarter which is also keeping the connection open.
316+
{ auto _ = kj::mv(tlsStarter); }
317+
318+
// Destroy the connection stream to close the connection.
319+
connectionStream = kj::none;
320+
310321
resolveFulfiller(js, kj::none);
311322
return js.resolvedPromise();
312323
}).catch_(js, [this](jsg::Lock& js, jsg::Value err) { errorHandler(js, kj::mv(err)); });
@@ -347,7 +358,7 @@ jsg::Ref<Socket> Socket::startTls(jsg::Lock& js, jsg::Optional<TlsOptions> tlsOp
347358
auto& context = IoContext::current();
348359

349360
self->writable->detach(js);
350-
self->readable = self->readable->detach(js, true);
361+
self->readable->detach(js, true);
351362

352363
// We should set this before closedResolver.resolve() in order to give the user
353364
// the option to check if the closed promise is resolved due to upgrade or not.
@@ -387,9 +398,18 @@ jsg::Ref<Socket> Socket::startTls(jsg::Lock& js, jsg::Optional<TlsOptions> tlsOp
387398
};
388399
}));
389400

401+
// Move the stream out of the plain text socket, to ensure the stream is properly
402+
// destroyed when the socket is closed.
403+
JSG_REQUIRE(self->connectionStream != kj::none, TypeError,
404+
"The connection was closed before startTls completed.");
405+
IoOwn<kj::RefcountedWrapper<kj::Own<kj::AsyncIoStream>>> wrapper =
406+
KJ_ASSERT_NONNULL(kj::mv(self->connectionStream));
407+
self->connectionStream = kj::none;
408+
390409
auto secureStream = forkedPromise.addBranch().then(
391-
[stream = self->connectionStream->addWrappedRef()]() mutable
392-
-> kj::Own<kj::AsyncIoStream> { return kj::mv(stream); });
410+
[stream = wrapper->addWrappedRef()]() mutable -> kj::Own<kj::AsyncIoStream> {
411+
return kj::mv(stream);
412+
});
393413

394414
return kj::newPromisedStream(kj::mv(secureStream));
395415
})));
@@ -519,13 +539,27 @@ jsg::Ref<Socket> SocketsModule::connect(
519539
}
520540

521541
kj::Own<kj::AsyncIoStream> Socket::takeConnectionStream(jsg::Lock& js) {
542+
// Set this so that if `close` is called after this, that no closure steps are taken and instead
543+
// the `close` is a no-op.
544+
isClosing = true;
545+
522546
// We do not care if the socket was disturbed, we require the user to ensure the socket is not
523547
// being used.
524548
writable->detach(js);
525549
readable->detach(js, true);
526550

551+
// Move the stream out of the socket, to ensure the stream is properly destroyed when the
552+
// caller is done with it.
553+
JSG_REQUIRE(connectionStream != kj::none, TypeError,
554+
"The socket connection is closed or was already taken.");
555+
IoOwn<kj::RefcountedWrapper<kj::Own<kj::AsyncIoStream>>> wrapper =
556+
KJ_ASSERT_NONNULL(kj::mv(connectionStream));
557+
connectionStream = kj::none;
558+
527559
closedResolver.resolve(js);
528-
return connectionStream->addWrappedRef();
560+
561+
// Get a new reference to the wrapped stream via refcounting
562+
return wrapper->addWrappedRef();
529563
}
530564

531565
// Implementation of the custom factory for creating WorkerInterface instances from a socket

src/workerd/api/sockets.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ class Socket: public jsg::Object {
183183
// TODO(cleanup): Combine all the IoOwns here into one, to improve efficiency and make
184184
// shutdown order clearer.
185185

186-
IoOwn<kj::RefcountedWrapper<kj::Own<kj::AsyncIoStream>>> connectionStream;
186+
kj::Maybe<IoOwn<kj::RefcountedWrapper<kj::Own<kj::AsyncIoStream>>>> connectionStream;
187187
jsg::Ref<ReadableStream> readable;
188188
jsg::Ref<WritableStream> writable;
189189
// This fulfiller is used to resolve the `closedPromise` below.

src/workerd/api/system-streams.c++

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -380,13 +380,13 @@ kj::Own<WritableStreamSink> newSystemStream(
380380
return kj::heap<EncodedAsyncOutputStream>(kj::mv(inner), encoding, context);
381381
}
382382

383-
SystemMultiStream newSystemMultiStream(kj::Own<kj::AsyncIoStream> stream, IoContext& context) {
383+
SystemMultiStream newSystemMultiStream(
384+
kj::RefcountedWrapper<kj::Own<kj::AsyncIoStream>>& stream, IoContext& context) {
384385

385-
auto wrapped = kj::refcountedWrapper(kj::mv(stream));
386386
return {.readable = kj::heap<EncodedAsyncInputStream>(
387-
wrapped->addWrappedRef(), StreamEncoding::IDENTITY, context),
387+
stream.addWrappedRef(), StreamEncoding::IDENTITY, context),
388388
.writable = kj::heap<EncodedAsyncOutputStream>(
389-
wrapped->addWrappedRef(), StreamEncoding::IDENTITY, context)};
389+
stream.addWrappedRef(), StreamEncoding::IDENTITY, context)};
390390
}
391391

392392
ContentEncodingOptions::ContentEncodingOptions(CompatibilityFlags::Reader flags)

src/workerd/api/system-streams.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ struct SystemMultiStream {
3939
};
4040

4141
// A combo ReadableStreamSource and WritableStreamSink.
42-
SystemMultiStream newSystemMultiStream(
43-
kj::Own<kj::AsyncIoStream> stream, IoContext& context = IoContext::current());
42+
SystemMultiStream newSystemMultiStream(kj::RefcountedWrapper<kj::Own<kj::AsyncIoStream>>& stream,
43+
IoContext& context = IoContext::current());
4444

4545
struct ContentEncodingOptions {
4646
bool brotliEnabled = false;

0 commit comments

Comments
 (0)