From 764d773529ae3077f86b1f2a30127917ac0bbc3b Mon Sep 17 00:00:00 2001 From: SeanChinJunKai Date: Sat, 19 Apr 2025 22:05:14 +0800 Subject: [PATCH 1/8] feat: allow JSONRPCResponse to have nullable id --- .../kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt index f84233e8..f7dbc797 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt @@ -240,7 +240,7 @@ public data class JSONRPCNotification( */ @Serializable public class JSONRPCResponse( - public val id: RequestId, + public val id: RequestId?, public val jsonrpc: String = JSONRPC_VERSION, public val result: RequestResult? = null, public val error: JSONRPCError? = null, From 096b28ffa6f2d5570f2a58435f4f8faaf25a0ebb Mon Sep 17 00:00:00 2001 From: SeanChinJunKai Date: Mon, 28 Apr 2025 00:16:00 +0800 Subject: [PATCH 2/8] feat: add StreamableHttpTransport for server --- .../server/StreamableHttpServerTransport.kt | 304 ++++++++++++++++++ 1 file changed, 304 insertions(+) create mode 100644 src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt new file mode 100644 index 00000000..afcf1973 --- /dev/null +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -0,0 +1,304 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.ktor.http.* +import io.ktor.server.application.* +import io.ktor.server.request.* +import io.ktor.server.response.* +import io.ktor.server.sse.* +import io.modelcontextprotocol.kotlin.sdk.* +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import kotlinx.serialization.encodeToString +import kotlin.collections.HashMap +import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +@OptIn(ExperimentalAtomicApi::class) +public class StreamableHttpServerTransport( + private val isStateful: Boolean = false, + private val enableJSONResponse: Boolean = false, +): AbstractTransport() { + private val standalone = "standalone" + private val streamMapping: HashMap = hashMapOf() + private val requestToStreamMapping: HashMap = hashMapOf() + private val requestResponseMapping: HashMap = hashMapOf() + private val callMapping: HashMap = hashMapOf() + private val started: AtomicBoolean = AtomicBoolean(false) + private val initialized: AtomicBoolean = AtomicBoolean(false) + + public var sessionId: String? = null + private set + + override suspend fun start() { + if (!started.compareAndSet(false, true)) { + error("StreamableHttpServerTransport already started! If using Server class, note that connect() calls start() automatically.") + } + } + + override suspend fun send(message: JSONRPCMessage) { + var requestId: RequestId? = null + + if (message is JSONRPCResponse) { + requestId = message.id + } + + if (requestId == null) { + val standaloneSSE = streamMapping[standalone] ?: return + + standaloneSSE.send( + event = "message", + data = McpJson.encodeToString(message), + ) + return + } + + val streamId = requestToStreamMapping[requestId] ?: error("No connection established for request id $requestId") + val correspondingStream = streamMapping[streamId] ?: error("No connection established for request id $requestId") + val correspondingCall = callMapping[streamId] ?: error("No connection established for request id $requestId") + + if (!enableJSONResponse) { + correspondingStream.send( + event = "message", + data = McpJson.encodeToString(message), + ) + } + + requestResponseMapping[requestId] = message + val relatedIds = requestToStreamMapping.entries.filter { streamMapping[it.value] == correspondingStream }.map { it.key } + val allResponsesReady = relatedIds.all { requestResponseMapping[it] != null } + + if (allResponsesReady) { + if (enableJSONResponse) { + correspondingCall.response.headers.append(ContentType.toString(), ContentType.Application.Json.toString()) + correspondingCall.response.status(HttpStatusCode.OK) + if (sessionId != null) { + correspondingCall.response.header("Mcp-Session-Id", sessionId!!) + } + val responses = relatedIds.map{ requestResponseMapping[it] } + if (responses.size == 1) { + correspondingCall.respond(responses[0]!!) + } else { + correspondingCall.respond(responses) + } + callMapping.remove(streamId) + } else { + correspondingStream.close() + streamMapping.remove(streamId) + } + + for (id in relatedIds) { + requestToStreamMapping.remove(id) + requestResponseMapping.remove(id) + } + } + + } + + override suspend fun close() { + streamMapping.values.forEach { + it.close() + } + streamMapping.clear() + requestToStreamMapping.clear() + requestResponseMapping.clear() + // TODO Check if we need to clear the callMapping or if call timeout after awhile + _onClose.invoke() + } + + @OptIn(ExperimentalUuidApi::class) + public suspend fun handlePostRequest(call: ApplicationCall, session: ServerSSESession) { + try { + val acceptHeader = call.request.headers["Accept"]?.split(",") ?: listOf() + + if (!acceptHeader.contains("text/event-stream") || !acceptHeader.contains("application/json")) { + call.response.status(HttpStatusCode.NotAcceptable) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Unknown(-32000), + message = "Not Acceptable: Client must accept both application/json and text/event-stream" + ) + ) + ) + return + } + + val contentType = call.request.contentType() + if (contentType != ContentType.Application.Json) { + call.response.status(HttpStatusCode.UnsupportedMediaType) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Unknown(-32000), + message = "Unsupported Media Type: Content-Type must be application/json" + ) + ) + ) + return + } + + val body = call.receiveText() + val messages = mutableListOf() + + if (body.startsWith("[")) { + messages.addAll(McpJson.decodeFromString>(body)) + } else { + messages.add(McpJson.decodeFromString(body)) + } + + val hasInitializationRequest = messages.any { it is JSONRPCRequest && it.method == "initialize" } + if (hasInitializationRequest) { + if (initialized.load() && sessionId != null) { + call.response.status(HttpStatusCode.BadRequest) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Defined.InvalidRequest, + message = "Invalid Request: Server already initialized" + ) + ) + ) + return + } + + if (messages.size > 1) { + call.response.status(HttpStatusCode.BadRequest) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Defined.InvalidRequest, + message = "Invalid Request: Only one initialization request is allowed" + ) + ) + ) + return + } + + if (isStateful) { + sessionId = Uuid.random().toString() + } + initialized.store(true) + + if (!validateSession(call)) { + return + } + + val hasRequests = messages.any { it is JSONRPCRequest } + val streamId = Uuid.random().toString() + + if (!hasRequests){ + call.respondNullable(HttpStatusCode.Accepted) + } else { + if (!enableJSONResponse) { + call.response.headers.append(ContentType.toString(), ContentType.Text.EventStream.toString()) + + if (sessionId != null) { + call.response.header("Mcp-Session-Id", sessionId!!) + } + } + + for (message in messages) { + if (message is JSONRPCRequest) { + streamMapping[streamId] = session + callMapping[streamId] = call + requestToStreamMapping[message.id] = streamId + } + } + } + for (message in messages) { + _onMessage.invoke(message) + } + } + + } catch (e: Exception) { + call.response.status(HttpStatusCode.BadRequest) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Unknown(-32000), + message = e.message ?: "Parse error" + ) + ) + ) + _onError.invoke(e) + } + } + + public suspend fun handleGetRequest(call: ApplicationCall, session: ServerSSESession) { + val acceptHeader = call.request.headers["Accept"]?.split(",") ?: listOf() + if (!acceptHeader.contains("text/event-stream")) { + call.response.status(HttpStatusCode.NotAcceptable) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Unknown(-32000), + message = "Not Acceptable: Client must accept text/event-stream" + ) + ) + ) + } + + if (!validateSession(call)) { + return + } + + if (sessionId != null) { + call.response.header("Mcp-Session-Id", sessionId!!) + } + + if (streamMapping[standalone] != null) { + call.response.status(HttpStatusCode.Conflict) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Unknown(-32000), + message = "Conflict: Only one SSE stream is allowed per session" + ) + ) + ) + session.close() + return + } + + // TODO: Equivalent of typescript res.writeHead(200, headers).flushHeaders(); + streamMapping[standalone] = session + } + + public suspend fun handleDeleteRequest(call: ApplicationCall) { + if (!validateSession(call)) { + return + } + close() + call.respondNullable(HttpStatusCode.OK) + } + + public suspend fun validateSession(call: ApplicationCall): Boolean { + if (sessionId == null) { + return true + } + + if (!initialized.load()) { + call.response.status(HttpStatusCode.BadRequest) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Unknown(-32000), + message = "Bad Request: Server not initialized" + ) + ) + ) + return false + } + return true + } +} From fe919360c66357b45b3022587e899f9654c949f9 Mon Sep 17 00:00:00 2001 From: SeanChinJunKai Date: Sat, 3 May 2025 19:09:02 +0800 Subject: [PATCH 3/8] feat: minor refactoring --- .../server/StreamableHttpServerTransport.kt | 208 +++++++++++------- .../kotlin/sdk/shared/Constants.kt | 3 + .../kotlin/integration/SseIntegrationTest.kt | 0 3 files changed, 126 insertions(+), 85 deletions(-) create mode 100644 src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Constants.kt create mode 100644 src/jvmTest/kotlin/integration/SseIntegrationTest.kt diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index afcf1973..becad394 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -7,14 +7,24 @@ import io.ktor.server.response.* import io.ktor.server.sse.* import io.modelcontextprotocol.kotlin.sdk.* import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SESSION_ID import io.modelcontextprotocol.kotlin.sdk.shared.McpJson import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.decodeFromJsonElement import kotlin.collections.HashMap import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi import kotlin.uuid.ExperimentalUuidApi import kotlin.uuid.Uuid +/** + * Server transport for StreamableHttp: this allows server to respond to GET, POST and DELETE requests. Server can optionally make use of Server-Sent Events (SSE) to stream multiple server messages. + * + * Creates a new StreamableHttp server transport. + */ @OptIn(ExperimentalAtomicApi::class) public class StreamableHttpServerTransport( private val isStateful: Boolean = false, @@ -55,7 +65,8 @@ public class StreamableHttpServerTransport( } val streamId = requestToStreamMapping[requestId] ?: error("No connection established for request id $requestId") - val correspondingStream = streamMapping[streamId] ?: error("No connection established for request id $requestId") + val correspondingStream = + streamMapping[streamId] ?: error("No connection established for request id $requestId") val correspondingCall = callMapping[streamId] ?: error("No connection established for request id $requestId") if (!enableJSONResponse) { @@ -66,32 +77,33 @@ public class StreamableHttpServerTransport( } requestResponseMapping[requestId] = message - val relatedIds = requestToStreamMapping.entries.filter { streamMapping[it.value] == correspondingStream }.map { it.key } + val relatedIds = + requestToStreamMapping.entries.filter { streamMapping[it.value] == correspondingStream }.map { it.key } val allResponsesReady = relatedIds.all { requestResponseMapping[it] != null } - if (allResponsesReady) { - if (enableJSONResponse) { - correspondingCall.response.headers.append(ContentType.toString(), ContentType.Application.Json.toString()) - correspondingCall.response.status(HttpStatusCode.OK) - if (sessionId != null) { - correspondingCall.response.header("Mcp-Session-Id", sessionId!!) - } - val responses = relatedIds.map{ requestResponseMapping[it] } - if (responses.size == 1) { - correspondingCall.respond(responses[0]!!) - } else { - correspondingCall.respond(responses) - } - callMapping.remove(streamId) + if (!allResponsesReady) return + + if (enableJSONResponse) { + correspondingCall.response.headers.append(ContentType.toString(), ContentType.Application.Json.toString()) + correspondingCall.response.status(HttpStatusCode.OK) + if (sessionId != null) { + correspondingCall.response.header(MCP_SESSION_ID, sessionId!!) + } + val responses = relatedIds.map { requestResponseMapping[it] } + if (responses.size == 1) { + correspondingCall.respond(responses[0]!!) } else { - correspondingStream.close() - streamMapping.remove(streamId) + correspondingCall.respond(responses) } + callMapping.remove(streamId) + } else { + correspondingStream.close() + streamMapping.remove(streamId) + } - for (id in relatedIds) { - requestToStreamMapping.remove(id) - requestResponseMapping.remove(id) - } + for (id in relatedIds) { + requestToStreamMapping.remove(id) + requestResponseMapping.remove(id) } } @@ -110,47 +122,13 @@ public class StreamableHttpServerTransport( @OptIn(ExperimentalUuidApi::class) public suspend fun handlePostRequest(call: ApplicationCall, session: ServerSSESession) { try { - val acceptHeader = call.request.headers["Accept"]?.split(",") ?: listOf() + if (!validateHeaders(call)) return - if (!acceptHeader.contains("text/event-stream") || !acceptHeader.contains("application/json")) { - call.response.status(HttpStatusCode.NotAcceptable) - call.respond( - JSONRPCResponse( - id = null, - error = JSONRPCError( - code = ErrorCode.Unknown(-32000), - message = "Not Acceptable: Client must accept both application/json and text/event-stream" - ) - ) - ) - return - } + val messages = parseBody(call) - val contentType = call.request.contentType() - if (contentType != ContentType.Application.Json) { - call.response.status(HttpStatusCode.UnsupportedMediaType) - call.respond( - JSONRPCResponse( - id = null, - error = JSONRPCError( - code = ErrorCode.Unknown(-32000), - message = "Unsupported Media Type: Content-Type must be application/json" - ) - ) - ) - return - } - - val body = call.receiveText() - val messages = mutableListOf() - - if (body.startsWith("[")) { - messages.addAll(McpJson.decodeFromString>(body)) - } else { - messages.add(McpJson.decodeFromString(body)) - } + if (messages.isEmpty()) return - val hasInitializationRequest = messages.any { it is JSONRPCRequest && it.method == "initialize" } + val hasInitializationRequest = messages.any { it is JSONRPCRequest && it.method == Method.Defined.Initialize.value } if (hasInitializationRequest) { if (initialized.load() && sessionId != null) { call.response.status(HttpStatusCode.BadRequest) @@ -184,38 +162,37 @@ public class StreamableHttpServerTransport( sessionId = Uuid.random().toString() } initialized.store(true) + } - if (!validateSession(call)) { - return - } - - val hasRequests = messages.any { it is JSONRPCRequest } - val streamId = Uuid.random().toString() + if (!validateSession(call)) { + return + } - if (!hasRequests){ - call.respondNullable(HttpStatusCode.Accepted) - } else { - if (!enableJSONResponse) { - call.response.headers.append(ContentType.toString(), ContentType.Text.EventStream.toString()) + val hasRequests = messages.any { it is JSONRPCRequest } + val streamId = Uuid.random().toString() - if (sessionId != null) { - call.response.header("Mcp-Session-Id", sessionId!!) - } - } + if (!hasRequests) { + call.respondNullable(HttpStatusCode.Accepted) + } else { + if (!enableJSONResponse) { + call.response.headers.append(ContentType.toString(), ContentType.Text.EventStream.toString()) - for (message in messages) { - if (message is JSONRPCRequest) { - streamMapping[streamId] = session - callMapping[streamId] = call - requestToStreamMapping[message.id] = streamId - } + if (sessionId != null) { + call.response.header(MCP_SESSION_ID, sessionId!!) } } + for (message in messages) { - _onMessage.invoke(message) + if (message is JSONRPCRequest) { + streamMapping[streamId] = session + callMapping[streamId] = call + requestToStreamMapping[message.id] = streamId + } } } - + for (message in messages) { + _onMessage.invoke(message) + } } catch (e: Exception) { call.response.status(HttpStatusCode.BadRequest) call.respond( @@ -251,7 +228,7 @@ public class StreamableHttpServerTransport( } if (sessionId != null) { - call.response.header("Mcp-Session-Id", sessionId!!) + call.response.header(MCP_SESSION_ID, sessionId!!) } if (streamMapping[standalone] != null) { @@ -281,7 +258,7 @@ public class StreamableHttpServerTransport( call.respondNullable(HttpStatusCode.OK) } - public suspend fun validateSession(call: ApplicationCall): Boolean { + private suspend fun validateSession(call: ApplicationCall): Boolean { if (sessionId == null) { return true } @@ -301,4 +278,65 @@ public class StreamableHttpServerTransport( } return true } + + private suspend fun validateHeaders(call: ApplicationCall): Boolean { + val acceptHeader = call.request.headers["Accept"]?.split(",") ?: listOf() + + if (!acceptHeader.contains("text/event-stream") || !acceptHeader.contains("application/json")) { + call.response.status(HttpStatusCode.NotAcceptable) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Unknown(-32000), + message = "Not Acceptable: Client must accept both application/json and text/event-stream" + ) + ) + ) + return false + } + + val contentType = call.request.contentType() + if (contentType != ContentType.Application.Json) { + call.response.status(HttpStatusCode.UnsupportedMediaType) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Unknown(-32000), + message = "Unsupported Media Type: Content-Type must be application/json" + ) + ) + ) + return false + } + + return true + } + + private suspend fun parseBody( + call: ApplicationCall, + ): List { + val messages = mutableListOf() + when (val body = call.receive()) { + is JsonObject -> messages.add(McpJson.decodeFromJsonElement(body)) + is JsonArray -> messages.addAll(McpJson.decodeFromJsonElement>(body)) + else -> { + call.response.status(HttpStatusCode.BadRequest) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Defined.InvalidRequest, + message = "Invalid Request: Server already initialized" + ) + ) + ) + return listOf() + } + } + return messages + } + + } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Constants.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Constants.kt new file mode 100644 index 00000000..727aebec --- /dev/null +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Constants.kt @@ -0,0 +1,3 @@ +package io.modelcontextprotocol.kotlin.sdk.shared + +internal const val MCP_SESSION_ID = "mcp-session-id" \ No newline at end of file diff --git a/src/jvmTest/kotlin/integration/SseIntegrationTest.kt b/src/jvmTest/kotlin/integration/SseIntegrationTest.kt new file mode 100644 index 00000000..e69de29b From b7937f9b1084f1bc8206149d057f96dd624102fc Mon Sep 17 00:00:00 2001 From: devcrocod Date: Wed, 9 Jul 2025 12:47:47 +0200 Subject: [PATCH 4/8] api dump for StreamableHttpTransport --- api/kotlin-sdk.api | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/api/kotlin-sdk.api b/api/kotlin-sdk.api index 6b711675..8eff00d2 100644 --- a/api/kotlin-sdk.api +++ b/api/kotlin-sdk.api @@ -3034,6 +3034,19 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTranspor public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } +public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public fun ()V + public fun (ZZ)V + public synthetic fun (ZZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getSessionId ()Ljava/lang/String; + public final fun handleDeleteRequest (Lio/ktor/server/application/ApplicationCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun handleGetRequest (Lio/ktor/server/application/ApplicationCall;Lio/ktor/server/sse/ServerSSESession;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun handlePostRequest (Lio/ktor/server/application/ApplicationCall;Lio/ktor/server/sse/ServerSSESession;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public final class io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensionsKt { public static final fun mcpWebSocket (Lio/ktor/server/routing/Route;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function2;)V public static final fun mcpWebSocket (Lio/ktor/server/routing/Route;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function2;)V From 2737ae44e207058cf5bac3c91fb284736a414bdd Mon Sep 17 00:00:00 2001 From: devcrocod Date: Thu, 10 Jul 2025 03:40:20 +0200 Subject: [PATCH 5/8] Fix and refactor StreamableHttpTransport - add EventStore and support "Last-Event-ID" - Origin check - thread safety --- api/kotlin-sdk.api | 26 +- .../server/StreamableHttpServerTransport.kt | 596 +++++++++++------- .../kotlin/sdk/shared/Constants.kt | 3 - .../kotlin/sdk/shared/Protocol.kt | 2 +- .../modelcontextprotocol/kotlin/sdk/types.kt | 12 +- 5 files changed, 394 insertions(+), 245 deletions(-) delete mode 100644 src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Constants.kt diff --git a/api/kotlin-sdk.api b/api/kotlin-sdk.api index 8eff00d2..1cb7755e 100644 --- a/api/kotlin-sdk.api +++ b/api/kotlin-sdk.api @@ -959,19 +959,12 @@ public final class io/modelcontextprotocol/kotlin/sdk/InitializedNotification$Co public final class io/modelcontextprotocol/kotlin/sdk/JSONRPCError : io/modelcontextprotocol/kotlin/sdk/JSONRPCMessage { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError$Companion; - public fun (Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;)V - public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V - public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/ErrorCode; - public final fun component2 ()Ljava/lang/String; - public final fun component3 ()Lkotlinx/serialization/json/JsonObject; - public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;)Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError; - public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError; - public fun equals (Ljava/lang/Object;)Z + public fun (Lio/modelcontextprotocol/kotlin/sdk/RequestId;Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/RequestId;Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getCode ()Lio/modelcontextprotocol/kotlin/sdk/ErrorCode; public final fun getData ()Lkotlinx/serialization/json/JsonObject; + public final fun getId ()Lio/modelcontextprotocol/kotlin/sdk/RequestId; public final fun getMessage ()Ljava/lang/String; - public fun hashCode ()I - public fun toString ()Ljava/lang/String; } public final synthetic class io/modelcontextprotocol/kotlin/sdk/JSONRPCError$$serializer : kotlinx/serialization/internal/GeneratedSerializer { @@ -2922,6 +2915,11 @@ public final class io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorCli public static synthetic fun mcpWebSocketTransport$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport; } +public abstract interface class io/modelcontextprotocol/kotlin/sdk/server/EventStore { + public abstract fun replayEventsAfter (Ljava/lang/String;Lkotlin/jvm/functions/Function3;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public abstract fun storeEvent (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt { public static final fun MCP (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function0;)V public static final fun mcp (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function0;)V @@ -3035,15 +3033,19 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTranspor } public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public static final field STANDALONE_SSE_STREAM_ID Ljava/lang/String; public fun ()V - public fun (ZZ)V - public synthetic fun (ZZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;)V + public synthetic fun (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun getSessionId ()Ljava/lang/String; public final fun handleDeleteRequest (Lio/ktor/server/application/ApplicationCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun handleGetRequest (Lio/ktor/server/application/ApplicationCall;Lio/ktor/server/sse/ServerSSESession;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun handlePostRequest (Lio/ktor/server/application/ApplicationCall;Lio/ktor/server/sse/ServerSSESession;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun handleRequest (Lio/ktor/server/application/ApplicationCall;Lio/ktor/server/sse/ServerSSESession;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun setSessionIdGenerator (Lkotlin/jvm/functions/Function0;)V + public final fun setSessionInitialized (Lkotlin/jvm/functions/Function1;)V public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index becad394..143f4d24 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -1,342 +1,490 @@ package io.modelcontextprotocol.kotlin.sdk.server -import io.ktor.http.* -import io.ktor.server.application.* -import io.ktor.server.request.* -import io.ktor.server.response.* -import io.ktor.server.sse.* -import io.modelcontextprotocol.kotlin.sdk.* +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpMethod +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.ApplicationCall +import io.ktor.server.request.contentType +import io.ktor.server.request.host +import io.ktor.server.request.httpMethod +import io.ktor.server.request.receive +import io.ktor.server.response.header +import io.ktor.server.response.respond +import io.ktor.server.response.respondNullable +import io.ktor.server.sse.ServerSSESession +import io.ktor.util.collections.ConcurrentMap +import io.modelcontextprotocol.kotlin.sdk.ErrorCode +import io.modelcontextprotocol.kotlin.sdk.JSONRPCError +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.RequestId +import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport -import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SESSION_ID import io.modelcontextprotocol.kotlin.sdk.shared.McpJson -import kotlinx.serialization.encodeToString +import kotlinx.coroutines.job +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import kotlinx.serialization.json.JsonArray import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.decodeFromJsonElement -import kotlin.collections.HashMap import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi import kotlin.uuid.ExperimentalUuidApi import kotlin.uuid.Uuid +private const val MCP_SESSION_ID = "mcp-session-id" +private const val MCP_PROTOCOL_VERSION = "mcp-protocol-version" +private const val LAST_EVENT_ID = "Last-Event-ID" + +/** + * Interface for resumability support via event storage + */ +public interface EventStore { + /** + * Stores an event for later retrieval + * @param streamId ID of the stream the event belongs to + * @param message the JSON-RPC message to store + * @return the generated event ID for the stored event + */ + public suspend fun storeEvent(streamId: String, message: JSONRPCMessage): String + + /** + * Replays events after the specified event ID + * @param lastEventId The last event ID that was received + * @param sender Function to send events + * @return The stream ID for the replayed events + */ + public suspend fun replayEventsAfter( + lastEventId: String, + sender: suspend (eventId: String, message: JSONRPCMessage) -> Unit + ): String +} + +/** + * Simple holder for an active stream + * SSE session and corresponding call + */ +private data class ActiveStream(val sse: ServerSSESession, val call: ApplicationCall) + /** - * Server transport for StreamableHttp: this allows server to respond to GET, POST and DELETE requests. Server can optionally make use of Server-Sent Events (SSE) to stream multiple server messages. + * Server transport for StreamableHttp: this allows the server to respond to GET, POST and DELETE requests. + * Server can optionally make use of Server-Sent Events (SSE) to stream multiple server messages. * * Creates a new StreamableHttp server transport. + * + * @param enableJsonResponse If true, the server will return JSON responses instead of starting an SSE stream. + * This can be useful for simple request/response scenarios without streaming. + * Default is false (SSE streams are preferred). + * @param enableDnsRebindingProtection Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). + * Default is false for backwards compatibility. + * @param allowedHosts List of allowed host header values for DNS rebinding protection. + * If not specified, host validation is disabled. + * @param allowedOrigins List of allowed origin header values for DNS rebinding protection. + * If not specified, origin validation is disabled. + * @param eventStore Event store for resumability support. + * If provided, resumability will be enabled, allowing clients to reconnect and resume messages */ -@OptIn(ExperimentalAtomicApi::class) +@OptIn(ExperimentalAtomicApi::class, ExperimentalUuidApi::class) public class StreamableHttpServerTransport( - private val isStateful: Boolean = false, - private val enableJSONResponse: Boolean = false, -): AbstractTransport() { - private val standalone = "standalone" - private val streamMapping: HashMap = hashMapOf() - private val requestToStreamMapping: HashMap = hashMapOf() - private val requestResponseMapping: HashMap = hashMapOf() - private val callMapping: HashMap = hashMapOf() + private val enableJsonResponse: Boolean = false, + private val enableDnsRebindingProtection: Boolean = false, + private val allowedHosts: List? = null, + private val allowedOrigins: List? = null, + private val eventStore: EventStore? = null, +) : AbstractTransport() { + private var onSessionInitialized: ((String) -> Unit)? = null + private var sessionIdGenerator: (() -> String)? = { Uuid.random().toString() } + + private val streams: ConcurrentMap = ConcurrentMap() + private val requestToStream: ConcurrentMap = ConcurrentMap() + private val responses: ConcurrentMap = ConcurrentMap() + private val started: AtomicBoolean = AtomicBoolean(false) private val initialized: AtomicBoolean = AtomicBoolean(false) + private val sessionMutex = Mutex() + private val streamMutex = Mutex() + public var sessionId: String? = null private set + /** + * A callback for session initialization events + * This is called when the server initializes a new session. + * Useful in cases when you need to register multiple mcp sessions + * and need to keep track of them. + */ + public fun setSessionInitialized(block: ((String) -> Unit)?) { + onSessionInitialized = block + } + + /** + * Function that generates a session ID for the transport. + * The session ID SHOULD be globally unique and cryptographically secure + * (e.g., a securely generated UUID) + */ + public fun setSessionIdGenerator(block: (() -> String)?) { + sessionIdGenerator = block + } + override suspend fun start() { - if (!started.compareAndSet(false, true)) { - error("StreamableHttpServerTransport already started! If using Server class, note that connect() calls start() automatically.") + check(started.compareAndSet(expectedValue = false, newValue = true)) { + "StreamableHttpServerTransport already started! If using Server class, note that connect() calls start() automatically." } } override suspend fun send(message: JSONRPCMessage) { - var requestId: RequestId? = null - - if (message is JSONRPCResponse) { - requestId = message.id + val requestId: RequestId? = when (message) { + is JSONRPCResponse -> message.id + is JSONRPCError -> message.id + else -> null } + // Standalone SSE stream if (requestId == null) { - val standaloneSSE = streamMapping[standalone] ?: return - - standaloneSSE.send( - event = "message", - data = McpJson.encodeToString(message), - ) + require(message !is JSONRPCResponse && message !is JSONRPCError) { + "Cannot send a response on a standalone SSE stream unless resuming a previous client request" + } + val standaloneStream = streams[STANDALONE_SSE_STREAM_ID] ?: return + emitOnStream(STANDALONE_SSE_STREAM_ID, standaloneStream, message) return } - val streamId = requestToStreamMapping[requestId] ?: error("No connection established for request id $requestId") - val correspondingStream = - streamMapping[streamId] ?: error("No connection established for request id $requestId") - val correspondingCall = callMapping[streamId] ?: error("No connection established for request id $requestId") + val streamId = requestToStream[requestId] ?: error("No connection established for request id $requestId") + val activeStream = streams[streamId] ?: error("No connection established for request id $requestId") - if (!enableJSONResponse) { - correspondingStream.send( - event = "message", - data = McpJson.encodeToString(message), - ) + if (!enableJsonResponse) { + emitOnStream(streamId, activeStream, message) } - requestResponseMapping[requestId] = message - val relatedIds = - requestToStreamMapping.entries.filter { streamMapping[it.value] == correspondingStream }.map { it.key } - val allResponsesReady = relatedIds.all { requestResponseMapping[it] != null } + val isTerminal = message is JSONRPCResponse || message is JSONRPCError + if (!isTerminal) return - if (!allResponsesReady) return + responses[requestId] = message + val relatedIds = requestToStream.filterValues { it == streamId }.keys - if (enableJSONResponse) { - correspondingCall.response.headers.append(ContentType.toString(), ContentType.Application.Json.toString()) - correspondingCall.response.status(HttpStatusCode.OK) - if (sessionId != null) { - correspondingCall.response.header(MCP_SESSION_ID, sessionId!!) - } - val responses = relatedIds.map { requestResponseMapping[it] } - if (responses.size == 1) { - correspondingCall.respond(responses[0]!!) + if (relatedIds.any { it !in responses }) return + + streamMutex.withLock { + val active = streams[streamId] ?: return + val call = active.call + if (enableJsonResponse) { + call.response.header(HttpHeaders.ContentType, ContentType.Application.Json.toString()) + sessionId?.let { call.response.header(MCP_SESSION_ID, it) } + val payload = relatedIds.mapNotNull(responses::remove) + call.respond(if (payload.size == 1) payload.first() else payload) } else { - correspondingCall.respond(responses) + active.sse.close() } - callMapping.remove(streamId) - } else { - correspondingStream.close() - streamMapping.remove(streamId) + streams.remove(streamId) + relatedIds.forEach { requestToStream -= it } } + } - for (id in relatedIds) { - requestToStreamMapping.remove(id) - requestResponseMapping.remove(id) + override suspend fun close() { + streamMutex.withLock { + streams.values.forEach { + try { + it.sse.close() + } catch (_: Exception) { + } + } + streams.clear() + requestToStream.clear() + responses.clear() } - + _onClose() } - override suspend fun close() { - streamMapping.values.forEach { - it.close() + /** + * Handles an incoming HTTP request: GET, POST or DELETE + */ + public suspend fun handleRequest(call: ApplicationCall, session: ServerSSESession) { + validateHeaders(call)?.let { reason -> + call.reject(HttpStatusCode.Forbidden, ErrorCode.Unknown(-32000), reason) + _onError(Error(reason)) + return + } + + when (call.request.httpMethod) { + HttpMethod.Post -> handlePostRequest(call, session) + HttpMethod.Get -> handleGetRequest(call, session) + HttpMethod.Delete -> handleDeleteRequest(call) + else -> call.run { + response.header(HttpHeaders.Allow, "GET, POST, DELETE") + reject(HttpStatusCode.MethodNotAllowed, ErrorCode.Unknown(-32000), "Method not allowed.") + } } - streamMapping.clear() - requestToStreamMapping.clear() - requestResponseMapping.clear() - // TODO Check if we need to clear the callMapping or if call timeout after awhile - _onClose.invoke() } - @OptIn(ExperimentalUuidApi::class) - public suspend fun handlePostRequest(call: ApplicationCall, session: ServerSSESession) { + public suspend fun handlePostRequest(call: ApplicationCall, sse: ServerSSESession) { try { - if (!validateHeaders(call)) return + val acceptHeader = call.request.headers[HttpHeaders.Accept] + val acceptsEventStream = acceptHeader.accepts(ContentType.Text.EventStream) + val acceptsJson = acceptHeader.accepts(ContentType.Application.Json) + + if (!acceptsEventStream || !acceptsJson) { + call.reject( + HttpStatusCode.NotAcceptable, ErrorCode.Unknown(-32000), + "Not Acceptable: Client must accept both application/json and text/event-stream" + ) + return + } - val messages = parseBody(call) + if (call.request.contentType() != ContentType.Application.Json) { + call.reject( + HttpStatusCode.UnsupportedMediaType, ErrorCode.Unknown(-32000), + "Unsupported Media Type: Content-Type must be application/json" + ) + return + } - if (messages.isEmpty()) return + val messages = parseBody(call) ?: return + val isInitializationRequest = messages.any { + it is JSONRPCRequest && it.method == Method.Defined.Initialize.value + } - val hasInitializationRequest = messages.any { it is JSONRPCRequest && it.method == Method.Defined.Initialize.value } - if (hasInitializationRequest) { + if (isInitializationRequest) { if (initialized.load() && sessionId != null) { - call.response.status(HttpStatusCode.BadRequest) - call.respond( - JSONRPCResponse( - id = null, - error = JSONRPCError( - code = ErrorCode.Defined.InvalidRequest, - message = "Invalid Request: Server already initialized" - ) - ) + call.reject( + HttpStatusCode.BadRequest, ErrorCode.Defined.InvalidRequest, + "Invalid Request: Server already initialized" ) return } - if (messages.size > 1) { - call.response.status(HttpStatusCode.BadRequest) - call.respond( - JSONRPCResponse( - id = null, - error = JSONRPCError( - code = ErrorCode.Defined.InvalidRequest, - message = "Invalid Request: Only one initialization request is allowed" - ) - ) + call.reject( + HttpStatusCode.BadRequest, ErrorCode.Defined.InvalidRequest, + "Invalid Request: Only one initialization request is allowed" ) return } - if (isStateful) { - sessionId = Uuid.random().toString() + sessionMutex.withLock { + if (sessionId != null) return@withLock + sessionId = sessionIdGenerator?.invoke() + initialized.store(true) + sessionId?.let { onSessionInitialized?.invoke(it) } } - initialized.store(true) + } else { + if (!validateSession(call) || !validateProtocolVersion(call)) return } - if (!validateSession(call)) { + val hasRequests = messages.any { it is JSONRPCRequest } + if (!hasRequests) { + call.respondNullable(status = HttpStatusCode.Accepted, message = null) + messages.forEach { message -> _onMessage(message) } return } - val hasRequests = messages.any { it is JSONRPCRequest } val streamId = Uuid.random().toString() + if (!enableJsonResponse) { + call.appendSseHeaders() + sse.send(data = "") // flush headers immediately + } - if (!hasRequests) { - call.respondNullable(HttpStatusCode.Accepted) - } else { - if (!enableJSONResponse) { - call.response.headers.append(ContentType.toString(), ContentType.Text.EventStream.toString()) + streamMutex.withLock { + streams[streamId] = ActiveStream(sse, call) + messages.filterIsInstance().forEach { requestToStream[it.id] = streamId } + } + sse.coroutineContext.job.invokeOnCompletion { streams -= streamId } - if (sessionId != null) { - call.response.header(MCP_SESSION_ID, sessionId!!) - } - } + messages.forEach { message -> _onMessage(message) } - for (message in messages) { - if (message is JSONRPCRequest) { - streamMapping[streamId] = session - callMapping[streamId] = call - requestToStreamMapping[message.id] = streamId - } - } - } - for (message in messages) { - _onMessage.invoke(message) - } } catch (e: Exception) { - call.response.status(HttpStatusCode.BadRequest) - call.respond( - JSONRPCResponse( - id = null, - error = JSONRPCError( - code = ErrorCode.Unknown(-32000), - message = e.message ?: "Parse error" - ) - ) + call.reject( + HttpStatusCode.BadRequest, + ErrorCode.Defined.ParseError, + "Parse error: ${e.message}" ) - _onError.invoke(e) + _onError(e) } } - public suspend fun handleGetRequest(call: ApplicationCall, session: ServerSSESession) { - val acceptHeader = call.request.headers["Accept"]?.split(",") ?: listOf() - if (!acceptHeader.contains("text/event-stream")) { - call.response.status(HttpStatusCode.NotAcceptable) - call.respond( - JSONRPCResponse( - id = null, - error = JSONRPCError( - code = ErrorCode.Unknown(-32000), - message = "Not Acceptable: Client must accept text/event-stream" - ) - ) + public suspend fun handleGetRequest(call: ApplicationCall, sse: ServerSSESession) { + val acceptHeader = call.request.headers[HttpHeaders.Accept] + if (!acceptHeader.accepts(ContentType.Text.EventStream)) { + call.reject( + HttpStatusCode.NotAcceptable, ErrorCode.Unknown(-32000), + "Not Acceptable: Client must accept text/event-stream" ) - } - - if (!validateSession(call)) { return } - if (sessionId != null) { - call.response.header(MCP_SESSION_ID, sessionId!!) + if (!validateSession(call) || !validateProtocolVersion(call)) return + + eventStore?.let { store -> + call.request.headers[LAST_EVENT_ID]?.let { lastEventId -> + replayEvents(store, lastEventId, call, sse) + return + } } - if (streamMapping[standalone] != null) { - call.response.status(HttpStatusCode.Conflict) - call.respond( - JSONRPCResponse( - id = null, - error = JSONRPCError( - code = ErrorCode.Unknown(-32000), - message = "Conflict: Only one SSE stream is allowed per session" - ) - ) + if (STANDALONE_SSE_STREAM_ID in streams) { + call.reject( + HttpStatusCode.Conflict, ErrorCode.Unknown(-32000), + "Conflict: Only one SSE stream is allowed per session" ) - session.close() return } - // TODO: Equivalent of typescript res.writeHead(200, headers).flushHeaders(); - streamMapping[standalone] = session + call.appendSseHeaders() + sse.send(data = "") + streams[STANDALONE_SSE_STREAM_ID] = ActiveStream(sse, call) + sse.coroutineContext.job.invokeOnCompletion { streams -= STANDALONE_SSE_STREAM_ID } } public suspend fun handleDeleteRequest(call: ApplicationCall) { - if (!validateSession(call)) { - return - } + if (!validateSession(call) || !validateProtocolVersion(call)) return close() - call.respondNullable(HttpStatusCode.OK) + call.respondNullable(status = HttpStatusCode.OK, message = null) } - private suspend fun validateSession(call: ApplicationCall): Boolean { - if (sessionId == null) { - return true + private suspend fun replayEvents( + store: EventStore, + lastId: String, + call: ApplicationCall, + session: ServerSSESession + ) { + try { + call.appendSseHeaders() + val streamId = store.replayEventsAfter(lastId) { eventId, message -> + try { + session.send( + event = "message", + id = eventId, + data = McpJson.encodeToString(message) + ) + } catch (e: Exception) { + _onError(e) + } + } + streams[streamId] = ActiveStream(session, call) + } catch (e: Exception) { + _onError(e) } + } + + private suspend fun validateSession(call: ApplicationCall): Boolean { + if (sessionIdGenerator == null) return true if (!initialized.load()) { - call.response.status(HttpStatusCode.BadRequest) - call.respond( - JSONRPCResponse( - id = null, - error = JSONRPCError( - code = ErrorCode.Unknown(-32000), - message = "Bad Request: Server not initialized" - ) - ) + call.reject( + HttpStatusCode.BadRequest, ErrorCode.Unknown(-32000), + "Bad Request: Server not initialized" ) return false } - return true - } - private suspend fun validateHeaders(call: ApplicationCall): Boolean { - val acceptHeader = call.request.headers["Accept"]?.split(",") ?: listOf() - - if (!acceptHeader.contains("text/event-stream") || !acceptHeader.contains("application/json")) { - call.response.status(HttpStatusCode.NotAcceptable) - call.respond( - JSONRPCResponse( - id = null, - error = JSONRPCError( - code = ErrorCode.Unknown(-32000), - message = "Not Acceptable: Client must accept both application/json and text/event-stream" - ) + val headerId = call.request.headers[MCP_SESSION_ID] + + return when { + headerId == null -> { + call.reject( + HttpStatusCode.BadRequest, ErrorCode.Unknown(-32000), + "Bad Request: Mcp-Session-Id header is required" ) - ) - return false + false + } + + headerId != sessionId -> { + call.reject( + HttpStatusCode.NotFound, ErrorCode.Unknown(-32001), + "Session not found" + ) + return false + } + + else -> true } + } + + private suspend fun validateProtocolVersion(call: ApplicationCall): Boolean { + val version = call.request.headers[MCP_PROTOCOL_VERSION] ?: LATEST_PROTOCOL_VERSION - val contentType = call.request.contentType() - if (contentType != ContentType.Application.Json) { - call.response.status(HttpStatusCode.UnsupportedMediaType) - call.respond( - JSONRPCResponse( - id = null, - error = JSONRPCError( - code = ErrorCode.Unknown(-32000), - message = "Unsupported Media Type: Content-Type must be application/json" + return if (version !in SUPPORTED_PROTOCOL_VERSIONS) { + call.reject( + HttpStatusCode.BadRequest, ErrorCode.Unknown(-32000), + "Bad Request: Unsupported protocol version (supported versions: ${ + SUPPORTED_PROTOCOL_VERSIONS.joinToString( + ", " ) - ) + })" ) - return false + false + } else { + true } + } - return true + private fun validateHeaders(call: ApplicationCall): String? { + if (!enableDnsRebindingProtection) return null + allowedHosts?.let { hosts -> + val hostHeader = call.request.host().substringBefore(':').lowercase() + if (hostHeader !in hosts.map { it.substringBefore(':').lowercase() }) { + return "Invalid Host header: $hostHeader" + } + } + allowedOrigins?.let { origins -> + val originHeader = call.request.headers[HttpHeaders.Origin]?.removeSuffix("/")?.lowercase() + if (originHeader !in origins.map { it.removeSuffix("/").lowercase() }) { + return "Invalid Origin header: $originHeader" + } + } + return null } - private suspend fun parseBody( - call: ApplicationCall, - ): List { - val messages = mutableListOf() - when (val body = call.receive()) { - is JsonObject -> messages.add(McpJson.decodeFromJsonElement(body)) - is JsonArray -> messages.addAll(McpJson.decodeFromJsonElement>(body)) + private suspend fun parseBody(call: ApplicationCall): List? { + return when (val element = call.receive()) { + is JsonObject -> listOf(McpJson.decodeFromJsonElement(element)) + is JsonArray -> McpJson.decodeFromJsonElement(element) else -> { - call.response.status(HttpStatusCode.BadRequest) - call.respond( - JSONRPCResponse( - id = null, - error = JSONRPCError( - code = ErrorCode.Defined.InvalidRequest, - message = "Invalid Request: Server already initialized" - ) - ) + call.reject( + HttpStatusCode.BadRequest, ErrorCode.Defined.ParseError, + "Invalid JSON format" ) - return listOf() + null } } - return messages } + private fun String?.accepts(mime: ContentType): Boolean { + if (this == null) return false + + val escaped = Regex.escape(mime.toString()) + val pattern = Regex("""(^|,\s*)$escaped(\s*;|$)""", RegexOption.IGNORE_CASE) + return pattern.containsMatchIn(this) + } + private suspend fun emitOnStream(streamId: String, active: ActiveStream, message: JSONRPCMessage) { + val eventId = eventStore?.storeEvent(streamId, message) + try { + active.sse.send(event = "message", id = eventId, data = McpJson.encodeToString(message)) + } catch (_: Exception) { + streams.remove(streamId) + } + } + + private suspend fun ApplicationCall.reject(status: HttpStatusCode, code: ErrorCode, message: String) { + this.response.status(status) + this.respond(JSONRPCResponse(id = null, error = JSONRPCError(code = code, message = message))) + } + + private fun ApplicationCall.appendSseHeaders() { + this.response.headers.append(HttpHeaders.ContentType, ContentType.Text.EventStream.toString()) + this.response.headers.append(HttpHeaders.CacheControl, "no-cache, no-transform") + this.response.headers.append(HttpHeaders.Connection, "keep-alive") + this.response.headers.append("X-Accel-Buffering", "no") + sessionId?.let { this.response.header(MCP_SESSION_ID, it) } + this.response.status(HttpStatusCode.OK) + } + + private companion object { + const val STANDALONE_SSE_STREAM_ID = "_GET_stream" + } } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Constants.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Constants.kt deleted file mode 100644 index 727aebec..00000000 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Constants.kt +++ /dev/null @@ -1,3 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.shared - -internal const val MCP_SESSION_ID = "mcp-session-id" \ No newline at end of file diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index 8ad733b1..d9b77089 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -248,7 +248,7 @@ public abstract class Protocol( JSONRPCResponse( id = request.id, error = JSONRPCError( - ErrorCode.Defined.MethodNotFound, + code = ErrorCode.Defined.MethodNotFound, message = "Server does not support ${request.method}", ) ) diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt index f7dbc797..c1bc7815 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt @@ -15,10 +15,11 @@ import kotlin.concurrent.atomics.ExperimentalAtomicApi import kotlin.concurrent.atomics.incrementAndFetch import kotlin.jvm.JvmInline -public const val LATEST_PROTOCOL_VERSION: String = "2024-11-05" +public const val LATEST_PROTOCOL_VERSION: String = "2025-03-26" public val SUPPORTED_PROTOCOL_VERSIONS: Array = arrayOf( LATEST_PROTOCOL_VERSION, + "2024-11-05", "2024-10-07", ) @@ -276,10 +277,11 @@ public sealed interface ErrorCode { * A response to a request that indicates an error occurred. */ @Serializable -public data class JSONRPCError( - val code: ErrorCode, - val message: String, - val data: JsonObject = EmptyJsonObject, +public class JSONRPCError( + public val id: RequestId? = null, + public val code: ErrorCode, + public val message: String, + public val data: JsonObject = EmptyJsonObject, ) : JSONRPCMessage /* Cancellation */ From abebd98a7bafef1350d72f62227366eef66be655 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Fri, 11 Jul 2025 15:15:59 +0200 Subject: [PATCH 6/8] Fix accepts headers and workaround for receive Json --- .../server/StreamableHttpServerTransport.kt | 22 +++++++++++++------ .../kotlin/integration/SseIntegrationTest.kt | 0 2 files changed, 15 insertions(+), 7 deletions(-) delete mode 100644 src/jvmTest/kotlin/integration/SseIntegrationTest.kt diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index 143f4d24..e0acd3cc 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -9,6 +9,7 @@ import io.ktor.server.request.contentType import io.ktor.server.request.host import io.ktor.server.request.httpMethod import io.ktor.server.request.receive +import io.ktor.server.request.receiveText import io.ktor.server.response.header import io.ktor.server.response.respond import io.ktor.server.response.respondNullable @@ -440,15 +441,22 @@ public class StreamableHttpServerTransport( } private suspend fun parseBody(call: ApplicationCall): List? { - return when (val element = call.receive()) { + val body = call.receiveText() + return when (val element = McpJson.parseToJsonElement(body)) { is JsonObject -> listOf(McpJson.decodeFromJsonElement(element)) - is JsonArray -> McpJson.decodeFromJsonElement(element) + is JsonArray -> McpJson.decodeFromJsonElement>(element) else -> { - call.reject( - HttpStatusCode.BadRequest, ErrorCode.Defined.ParseError, - "Invalid JSON format" + call.response.status(HttpStatusCode.BadRequest) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Defined.InvalidRequest, + message = "Invalid Request: Server already initialized" + ) + ) ) - null + return null } } } @@ -457,7 +465,7 @@ public class StreamableHttpServerTransport( if (this == null) return false val escaped = Regex.escape(mime.toString()) - val pattern = Regex("""(^|,\s*)$escaped(\s*;|$)""", RegexOption.IGNORE_CASE) + val pattern = Regex("""(^|,\s*)$escaped(\s*(;|,|$))""", RegexOption.IGNORE_CASE) return pattern.containsMatchIn(this) } diff --git a/src/jvmTest/kotlin/integration/SseIntegrationTest.kt b/src/jvmTest/kotlin/integration/SseIntegrationTest.kt deleted file mode 100644 index e69de29b..00000000 From 9613cd32bb833e614f8c805e98fcad2c292342a7 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Fri, 11 Jul 2025 16:15:35 +0200 Subject: [PATCH 7/8] Fix content type check in StreamableHttpServerTransport --- .../kotlin/sdk/server/StreamableHttpServerTransport.kt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index e0acd3cc..80599ec6 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -8,7 +8,6 @@ import io.ktor.server.application.ApplicationCall import io.ktor.server.request.contentType import io.ktor.server.request.host import io.ktor.server.request.httpMethod -import io.ktor.server.request.receive import io.ktor.server.request.receiveText import io.ktor.server.response.header import io.ktor.server.response.respond @@ -30,7 +29,6 @@ import kotlinx.coroutines.job import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlinx.serialization.json.JsonArray -import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.decodeFromJsonElement import kotlin.concurrent.atomics.AtomicBoolean @@ -237,7 +235,7 @@ public class StreamableHttpServerTransport( return } - if (call.request.contentType() != ContentType.Application.Json) { + if (!call.request.contentType().match(ContentType.Application.Json)) { call.reject( HttpStatusCode.UnsupportedMediaType, ErrorCode.Unknown(-32000), "Unsupported Media Type: Content-Type must be application/json" From a9449be9f53e5c745f3afc995699f590216fee6a Mon Sep 17 00:00:00 2001 From: Pavel Gorgulov Date: Fri, 11 Jul 2025 19:53:38 +0200 Subject: [PATCH 8/8] Update src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../kotlin/sdk/server/StreamableHttpServerTransport.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index 80599ec6..6cd08ba7 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -450,7 +450,7 @@ public class StreamableHttpServerTransport( id = null, error = JSONRPCError( code = ErrorCode.Defined.InvalidRequest, - message = "Invalid Request: Server already initialized" + message = "Invalid Request: unable to parse JSON body" ) ) )