diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt index 290c081..9b17846 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt @@ -12,6 +12,7 @@ import io.ktor.http.HttpHeaders import io.ktor.http.Url import io.ktor.http.append import io.ktor.http.isSuccess +import io.ktor.http.protocolWithAuthority import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.McpJson @@ -24,7 +25,6 @@ import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.launch -import kotlinx.serialization.encodeToString import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi import kotlin.properties.Delegates @@ -55,7 +55,18 @@ public class SseClientTransport( private var job: Job? = null private val baseUrl by lazy { - session.call.request.url.toString().removeSuffix("/sse") + val requestUrl = session.call.request.url.toString() + val url = Url(requestUrl) + var path = url.encodedPath + if (path.isEmpty()) { + url.protocolWithAuthority + } else if (path.endsWith("/")) { + url.protocolWithAuthority + path.removeSuffix("/") + } else { + // the last item is not a directory, so will not be taken into account + path = path.substring(0, path.lastIndexOf("/")) + url.protocolWithAuthority + path + } } override suspend fun start() { @@ -95,8 +106,7 @@ public class SseClientTransport( val eventData = event.data ?: "" // check url correctness - val maybeEndpoint = Url(baseUrl + eventData) - + val maybeEndpoint = Url("$baseUrl/${if (eventData.startsWith("/")) eventData.substring(1) else eventData}") endpoint.complete(maybeEndpoint.toString()) } catch (e: Exception) { _onError(e) diff --git a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt index fd4c3d6..de61a07 100644 --- a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt +++ b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt @@ -6,6 +6,7 @@ import io.ktor.server.application.install import io.ktor.server.cio.CIO import io.ktor.server.engine.embeddedServer import io.ktor.server.routing.post +import io.ktor.server.routing.route import io.ktor.server.routing.routing import io.ktor.server.sse.sse import io.ktor.util.collections.ConcurrentMap @@ -82,4 +83,43 @@ class SseTransportTest : BaseTransportTest() { testClientRead(client) server.stopSuspend() } + + @Test + fun `test sse path not root path`() = runTest { + val port = 3007 + val server = embeddedServer(CIO, port = port) { + install(io.ktor.server.sse.SSE) + val transports = ConcurrentMap() + routing { + route("/sse") { + sse { + mcpSseTransport("", transports).apply { + onMessage { + send(it) + } + + start() + } + } + + post { + mcpPostEndpoint(transports) + } + } + } + }.startSuspend(wait = false) + + val client = HttpClient { + install(SSE) + }.mcpSseTransport { + url { + host = "localhost" + this.port = port + pathSegments = listOf("sse") + } + } + + testClientRead(client) + server.stopSuspend() + } }