diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index d9808ffb..740c13e7 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -186,18 +186,30 @@ public abstract class Protocol( private suspend fun onNotification(notification: JSONRPCNotification) { LOGGER.trace { "Received notification: ${notification.method}" } - val function = notificationHandlers[notification.method] + + // Some MCP clients (like Cursor) don't include the "method" key in the params object + // This ensures method is always available in params when it's a JsonObject, + // preventing NPEs during notification routing and processing + val processedNotification = if (notification.params is JsonObject && !notification.params.containsKey("method")) { + notification.copy( + params = JsonObject(notification.params + ("method" to JsonPrimitive(notification.method))) + ) + } else { + notification + } + + val function = notificationHandlers[processedNotification.method] val property = fallbackNotificationHandler val handler = function ?: property if (handler == null) { - LOGGER.trace { "No handler found for notification: ${notification.method}" } + LOGGER.trace { "No handler found for notification: ${processedNotification.method}" } return } try { - handler(notification) + handler(processedNotification) } catch (cause: Throwable) { - LOGGER.error(cause) { "Error handling notification: ${notification.method}" } + LOGGER.error(cause) { "Error handling notification: ${processedNotification.method}" } onError(cause) } } diff --git a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ProtocolNotificationHandlingTest.kt b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ProtocolNotificationHandlingTest.kt new file mode 100644 index 00000000..5f5920cd --- /dev/null +++ b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ProtocolNotificationHandlingTest.kt @@ -0,0 +1,147 @@ +package io.modelcontextprotocol.kotlin.sdk.shared + +import io.modelcontextprotocol.kotlin.sdk.* +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.* +import kotlin.test.* + +class ProtocolNotificationHandlingTest { + private lateinit var protocol: TestProtocol + private lateinit var transport: TestTransport + + @BeforeTest + fun setUp() { + protocol = TestProtocol() + transport = TestTransport() + + runBlocking { + protocol.connect(transport) + } + } + + @Test + fun `onNotification adds method key to JsonObject params when missing`() = runTest { + val originalParams = buildJsonObject { put("data", 123) } + val notification = JSONRPCNotification( + method = "test/notification", + params = originalParams + ) + + transport.simulateMessage(notification) + + assertNotNull(protocol.receivedNotification, "Handler should have captured the notification") + val receivedParams = protocol.receivedNotification?.params + assertIs(receivedParams, "Params should be JsonObject") + + assertEquals( + buildJsonObject { + put("data", 123) + put("method", "test/notification") + }, + receivedParams + ) + assertFalse(protocol.errorHandlerCalled, "onError should not be called") + } + + @Test + fun `onNotification should not modify params if JsonObject and method exists`() = runTest { + val originalParams = buildJsonObject { + put("data", 123) + put("method", "test/notification") + } + val notification = JSONRPCNotification( + method = "test/notification", + params = originalParams + ) + + transport.simulateMessage(notification) + + assertNotNull(protocol.receivedNotification, "Handler should have captured the notification") + val receivedParams = protocol.receivedNotification?.params + assertIs(receivedParams, "Params should be JsonObject") + + // Because "method" already exists, it should be unchanged + assertEquals(originalParams, receivedParams) + assertFalse(protocol.errorHandlerCalled, "onError should not be called") + } + + @Test + fun `onNotification should not modify params if JsonArray`() = runTest { + val originalParams = buildJsonArray { add(1); add("test") } + val notification = JSONRPCNotification( + method = "test/notification", + params = originalParams + ) + + transport.simulateMessage(notification) + + assertNotNull(protocol.receivedNotification, "Handler should have captured the notification") + val receivedParams = protocol.receivedNotification?.params + assertIs(receivedParams, "Params should be JsonArray") + // Should remain unmodified + assertEquals(originalParams, receivedParams) + assertFalse(protocol.errorHandlerCalled, "onError should not be called") + } + + @Test + fun `onNotification should handle JsonNull params`() = runTest { + val notification = JSONRPCNotification( + method = "test/notification", + params = JsonNull + ) + + transport.simulateMessage(notification) + + assertNotNull(protocol.receivedNotification, "Handler should have captured the notification") + + // Should remain JsonNull + assertEquals(JsonNull, protocol.receivedNotification?.params) + assertFalse(protocol.errorHandlerCalled, "onError should not be called") + } + + @Test + fun `onNotification should call fallback handler if specific handler not found`() = runTest { + val notification = JSONRPCNotification( + method = "unregistered/notification", + params = buildJsonObject { put("value", true) } + ) + + transport.simulateMessage(notification) + + assertNotNull(protocol.receivedNotification, "Fallback handler should have captured the notification") + assertEquals("unregistered/notification", protocol.receivedNotification?.method) + val receivedParams = protocol.receivedNotification?.params + assertIs(receivedParams) + + // Because we had no specific handler, "method" gets auto-added + assertEquals( + buildJsonObject { + put("value", true) + put("method", "unregistered/notification") + }, + receivedParams + ) + assertFalse(protocol.errorHandlerCalled, "onError should not be called") + } + + @Test + fun `onNotification should call onError if handler throws exception`() = runTest { + val exception = RuntimeException("Handler error!") + protocol.notificationHandlers["error/notification"] = { + throw exception + } + + val notification = JSONRPCNotification( + method = "error/notification", + params = buildJsonObject { put("method", "error/notification") } + ) + + transport.simulateMessage(notification) + + // Because the handler throws, the protocol's onError callback should run + assertNull(protocol.receivedNotification, "Received notification should be null since handler threw") + assertTrue(protocol.errorHandlerCalled, "onError should have been called") + assertSame(exception, protocol.lastError, "onError should receive the correct exception") + } +} diff --git a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/TestProtocol.kt b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/TestProtocol.kt new file mode 100644 index 00000000..56d11a0f --- /dev/null +++ b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/TestProtocol.kt @@ -0,0 +1,29 @@ +package io.modelcontextprotocol.kotlin.sdk.shared + +import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification +import io.modelcontextprotocol.kotlin.sdk.Method + +internal class TestProtocol : Protocol(options = null) { + var receivedNotification: JSONRPCNotification? = null + var errorHandlerCalled = false + var lastError: Throwable? = null + + init { + notificationHandlers["test/notification"] = { notification -> + receivedNotification = notification + } + + fallbackNotificationHandler = { notification -> + receivedNotification = notification + } + } + + override fun assertCapabilityForMethod(method: Method) {} + override fun assertNotificationCapability(method: Method) {} + override fun assertRequestHandlerCapability(method: Method) {} + + override fun onError(cause: Throwable) { + errorHandlerCalled = true + lastError = cause + } +} diff --git a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/TestTransport.kt b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/TestTransport.kt new file mode 100644 index 00000000..266b4532 --- /dev/null +++ b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/TestTransport.kt @@ -0,0 +1,43 @@ +package io.modelcontextprotocol.kotlin.sdk.shared + +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.shared.Transport + +/** + * A simple fake Transport for testing, capturing the onMessage callback + * so we can trigger messages manually. + */ +internal class TestTransport : Transport { + private var onCloseCallback: (() -> Unit)? = null + private var onErrorCallback: ((Throwable) -> Unit)? = null + private var onMessageCallback: (suspend (JSONRPCMessage) -> Unit)? = null + + override fun onClose(callback: () -> Unit) { + onCloseCallback = callback + } + + override fun onError(callback: (Throwable) -> Unit) { + onErrorCallback = callback + } + + override fun onMessage(callback: suspend (JSONRPCMessage) -> Unit) { + onMessageCallback = callback + } + + override suspend fun start() { + // no-op for test + } + + override suspend fun close() { + onCloseCallback?.invoke() + } + + override suspend fun send(message: JSONRPCMessage) { + // we don’t need to do anything with outbound messages in these tests, + // unless you want to record them for verification + } + + suspend fun simulateMessage(message: JSONRPCMessage) { + onMessageCallback?.invoke(message) + } +}