diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala
index 306b633f6fa37..556cd84a21b28 100644
--- a/core/src/main/scala/kafka/network/SocketServer.scala
+++ b/core/src/main/scala/kafka/network/SocketServer.scala
@@ -731,8 +731,17 @@ private[kafka] abstract class Acceptor(val socketServer: SocketServer,
         s" sendBufferSize [actual|requested]: [${socketChannel.socket.getSendBufferSize}|$sendBufferSize]" +
         s" recvBufferSize [actual|requested]: [${socketChannel.socket.getReceiveBufferSize}|$recvBufferSize]")
       true
-    } else
+    } else {
+      // Processor.accept can presumably also return false when a non-blocking offer finds its queue
+      // full (NOTE(review): confirm against the rest of accept and the caller). Such a channel must
+      // stay open, so only close it via connectionQuotas once this processor has begun shutting down
+      // and the rejection is therefore final.
+      if (!processor.shouldRun.get()) {
+        val listenerName = ListenerName.normalised(endPoint.listener)
+        connectionQuotas.closeChannel(this, listenerName, socketChannel)
+      }
       false
+    }
   }
 
   /**
@@ -1154,6 +1163,10 @@ private[kafka] class Processor(
   def accept(socketChannel: SocketChannel,
              mayBlock: Boolean,
              acceptorBlockedPercentMeter: com.yammer.metrics.core.Meter): Boolean = {
+    // reject new connections if the processor is shutting down
+    if (!shouldRun.get())
+      return false
+
     val accepted = {
       if (newConnections.offer(socketChannel))
         true
diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
index 6a4b8d8ca672e..05171e41bd7ea 100644
--- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
+++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
@@ -934,6 +934,92 @@ class SocketServerTest {
     verifyRemoteConnectionClosed(conn)
   }
 
+  @Test
+  def testNoSocketLeakDuringShutdownRaceCondition(): Unit = {
+    // verify that when acceptor shuts down during connection assignment, no sockets are leaked
+    val address = localAddress
+
+    // create initial connection to establish baseline
+    val initialSocket = connect()
+    sendRequest(initialSocket, producerRequestBytes())
+    receiveRequest(server.dataPlaneRequestChannel)
+
+    val initialConnectionCount = server.connectionCount(address)
+    assertTrue(initialConnectionCount > 0, "Should have at least one connection")
+
+    // no fixed sleep between the two steps: the final waitUntilTrue already polls for the outcome
+    val acceptors = server.dataPlaneAcceptors.asScala.values
+    acceptors.foreach(_.beginShutdown())
+    acceptors.foreach(_.close())
+
+    // cleanup initial connection
+    initialSocket.setSoLinger(true, 0)
+    initialSocket.close()
+
+    // critical verification for KAFKA-16765: No socket leaks
+    // all sockets must be properly closed during shutdown, not queued indefinitely
+    TestUtils.waitUntilTrue(() => server.connectionCount(address) == 0,
+      "Socket leak detected - connections not closed after shutdown")
+  }
+
+  @Test
+  def testAcceptorClosesSocketWhenProcessorRejectsDuringShutdown(): Unit = {
+    // explicitly simulate the scenario where Acceptor.assignNewConnection attempts to
+    // assign a SocketChannel after the Processor has begun shutting down. The
+    // expected behavior is that the assignment is rejected and the socket is closed
+    // (no leak).
+
+    // shutdown the default server created in @BeforeEach to avoid metric collisions
+    shutdownServerAndMetrics(server)
+    val testServer = new TestableSocketServer()
+    try {
+      testServer.enableRequestProcessing(Map()).get(1, TimeUnit.MINUTES)
+      val acceptor = testServer.testableAcceptor
+      val processor = acceptor.processors(0)
+
+      // create a socket channel pair (use SocketChannel.open so getChannel != null)
+      val listenSocket = new ServerSocket(0)
+      val channel = java.nio.channels.SocketChannel.open(new InetSocketAddress("localhost", listenSocket.getLocalPort))
+      val serverConn = listenSocket.accept()
+
+      try {
+        // simulate the accept path incrementing the connection count as Acceptor.accept() would
+        val listenerName = ListenerName.normalised(testServer.config.dataPlaneListeners.head.listener)
+        val blockedMeter = new org.apache.kafka.server.metrics.KafkaMetricsGroup("kafka.network", "Acceptor").newMeter("blockedPercentMeter", "blocked time", TimeUnit.NANOSECONDS)
+        testServer.connectionQuotas.inc(listenerName, channel.socket.getInetAddress, blockedMeter)
+
+        // begin processor shutdown so it will reject new connections
+        processor.beginShutdown()
+        // wait for processor to acknowledge shutdown; avoid fixed sleeps
+        TestUtils.waitUntilTrue(() => !processor.shouldRun.get(), "Processor did not begin shutdown")
+
+        // use reflection to invoke the private assignNewConnection to simulate the
+        // acceptor attempting to assign the just-accepted channel after shutdown
+        val m = classOf[Acceptor].getDeclaredMethod("assignNewConnection", classOf[SocketChannel], classOf[Processor], java.lang.Boolean.TYPE)
+        m.setAccessible(true)
+        val assigned = m.invoke(acceptor, channel, processor, java.lang.Boolean.FALSE.asInstanceOf[Object]).asInstanceOf[Boolean]
+
+        // should be rejected and socket closed
+        assertFalse(assigned, "assignNewConnection should reject when processor is shutting down")
+        // channel or its underlying socket must be closed
+        assertTrue(!channel.isOpen || channel.socket.isClosed, "SocketChannel should be closed after rejected assignment")
+
+        // ensure the server does not report leaked connections
+        TestUtils.waitUntilTrue(() => testServer.connectionCount(localAddress) == 0,
+          "Connection should not be leaked after rejected assignment")
+      } finally {
+        // best-effort cleanup: ignore ordinary close failures but let fatal errors propagate
+        def closeQuietly(resource: java.io.Closeable): Unit =
+          try resource.close() catch { case scala.util.control.NonFatal(_) => }
+        closeQuietly(channel)
+        closeQuietly(serverConn)
+        closeQuietly(listenSocket)
+      }
+    } finally {
+      shutdownServerAndMetrics(testServer)
+    }
+  }
+
   private def verifyRemoteConnectionClosed(connection: Socket): Unit = {
     val largeChunkOfBytes = new Array[Byte](1000000)
     // doing a subsequent send should throw an exception as the connection should be closed.
@@ -1504,8 +1590,10 @@ class SocketServerTest {
       // attempting to send response. Otherwise, the channel should be removed when all completed buffers are processed.
       // Channel should be closed and removed even if there is a partial buffered request when `hasIncomplete=true`
       val numRequests = if (responseRequiredIndex >= 0) responseRequiredIndex + 1 else numComplete
+      // Use a longer timeout for receiveRequest here since under full test-suite runs on loaded CI
+      // machines processing can be slower and short timeouts make these tests flaky.
       (0 until numRequests).foreach { i =>
-        val request = receiveRequest(testableServer.dataPlaneRequestChannel)
+        val request = receiveRequest(testableServer.dataPlaneRequestChannel, timeout = 30000L)
         if (i == numComplete - 1 && hasIncomplete)
           truncateBufferedRequest(channel)
         if (responseRequiredIndex == i)