@@ -156,7 +156,7 @@ jsg::Ref<Socket> setupSocket(jsg::Lock& js,
156156
157157 auto refcountedConnection = kj::refcountedWrapper (kj::mv (connection));
158158 // Initialize the readable/writable streams with the readable/writable sides of an AsyncIoStream.
159- auto sysStreams = newSystemMultiStream (refcountedConnection-> addWrappedRef () , ioContext);
159+ auto sysStreams = newSystemMultiStream (* refcountedConnection, ioContext);
160160 auto readable = js.alloc <ReadableStream>(ioContext, kj::mv (sysStreams.readable ));
161161 auto allowHalfOpen = getAllowHalfOpen (options);
162162 kj::Maybe<jsg::Promise<void >> eofPromise;
@@ -301,12 +301,23 @@ jsg::Promise<void> Socket::close(jsg::Lock& js) {
301301 // Forcibly abort the readable/writable streams.
302302 auto cancelPromise = readable->getController ().cancel (js, kj::none);
303303 auto abortPromise = writable->getController ().abort (js, kj::none);
304+
304305 // The below is effectively `Promise.all(cancelPromise, abortPromise)`
305306 return cancelPromise.then (js, [abortPromise = kj::mv (abortPromise)](jsg::Lock& js) mutable {
306307 return kj::mv (abortPromise);
307308 });
308309 })
309310 .then (js, [this ](jsg::Lock& js) {
311+ // This task needs to destroyed prior to destroying the AsyncIoStream as it is awaiting
312+ // that stream's `whenWriteDisconnected` promise.
313+ watchForDisconnectTask = nullptr ;
314+
315+ // Destroy the tlsStarter which is also keeping the connection open.
316+ { auto _ = kj::mv (tlsStarter); }
317+
318+ // Destroy the connection stream to close the connection.
319+ connectionStream = kj::none;
320+
310321 resolveFulfiller (js, kj::none);
311322 return js.resolvedPromise ();
312323 }).catch_ (js, [this ](jsg::Lock& js, jsg::Value err) { errorHandler (js, kj::mv (err)); });
@@ -347,7 +358,7 @@ jsg::Ref<Socket> Socket::startTls(jsg::Lock& js, jsg::Optional<TlsOptions> tlsOp
347358 auto & context = IoContext::current ();
348359
349360 self->writable ->detach (js);
350- self->readable = self-> readable ->detach (js, true );
361+ self->readable ->detach (js, true );
351362
352363 // We should set this before closedResolver.resolve() in order to give the user
353364 // the option to check if the closed promise is resolved due to upgrade or not.
@@ -387,9 +398,18 @@ jsg::Ref<Socket> Socket::startTls(jsg::Lock& js, jsg::Optional<TlsOptions> tlsOp
387398 };
388399 }));
389400
401+ // Move the stream out of the plain text socket, to ensure the stream is properly
402+ // destroyed when the socket is closed.
403+ JSG_REQUIRE (self->connectionStream != kj::none, TypeError,
404+ " The connection was closed before startTls completed." );
405+ IoOwn<kj::RefcountedWrapper<kj::Own<kj::AsyncIoStream>>> wrapper =
406+ KJ_ASSERT_NONNULL (kj::mv (self->connectionStream ));
407+ self->connectionStream = kj::none;
408+
390409 auto secureStream = forkedPromise.addBranch ().then (
391- [stream = self->connectionStream ->addWrappedRef ()]() mutable
392- -> kj::Own<kj::AsyncIoStream> { return kj::mv (stream); });
410+ [stream = wrapper->addWrappedRef ()]() mutable -> kj::Own<kj::AsyncIoStream> {
411+ return kj::mv (stream);
412+ });
393413
394414 return kj::newPromisedStream (kj::mv (secureStream));
395415 })));
@@ -519,13 +539,27 @@ jsg::Ref<Socket> SocketsModule::connect(
519539}
520540
521541kj::Own<kj::AsyncIoStream> Socket::takeConnectionStream (jsg::Lock& js) {
542+ // Set this so that if `close` is called after this, that no closure steps are taken and instead
543+ // the `close` is a no-op.
544+ isClosing = true ;
545+
522546 // We do not care if the socket was disturbed, we require the user to ensure the socket is not
523547 // being used.
524548 writable->detach (js);
525549 readable->detach (js, true );
526550
551+ // Move the stream out of the socket, to ensure the stream is properly destroyed when the
552+ // caller is done with it.
553+ JSG_REQUIRE (connectionStream != kj::none, TypeError,
554+ " The socket connection is closed or was already taken." );
555+ IoOwn<kj::RefcountedWrapper<kj::Own<kj::AsyncIoStream>>> wrapper =
556+ KJ_ASSERT_NONNULL (kj::mv (connectionStream));
557+ connectionStream = kj::none;
558+
527559 closedResolver.resolve (js);
528- return connectionStream->addWrappedRef ();
560+
561+ // Get a new reference to the wrapped stream via refcounting
562+ return wrapper->addWrappedRef ();
529563}
530564
531565// Implementation of the custom factory for creating WorkerInterface instances from a socket
0 commit comments