Skip to content

BAEL-9293 - Securing Spring AI MCP servers with OAuth2 #18630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions spring-ai-3/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-mcp-server-webmvc-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.hsqldb</groupId>
<artifactId>hsqldb</artifactId>
Expand All @@ -61,6 +65,16 @@
<artifactId>spring-ai-starter-model-openai</artifactId>
<version>${spring-ai-start-model-openai.version}</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-oauth2-resource-server</artifactId>
<version>${oauth2-resource-server.version}</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-oauth2-authorization-server</artifactId>
<version>${oauth2-authorization-server.version}</version>
</dependency>

<!-- Test dependencies -->
<dependency>
Expand Down Expand Up @@ -146,6 +160,8 @@
<spring-boot.version>3.4.5</spring-boot.version>
<spring-ai.version>1.0.0-M6</spring-ai.version>
<spring-ai-start-model-openai.version>1.0.0-M7</spring-ai-start-model-openai.version>
<oauth2-resource-server.version>3.4.2</oauth2-resource-server.version>
<oauth2-authorization-server.version>3.3.3</oauth2-authorization-server.version>
</properties>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.baeldung.springai.mcp.oauth2;

import org.springframework.ai.autoconfigure.chat.client.ChatClientAutoConfiguration;
import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.model.openai.autoconfigure.*;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.data.mongo.MongoDataAutoConfiguration;
import org.springframework.boot.autoconfigure.mongo.MongoAutoConfiguration;

@SpringBootApplication(exclude = {
ChatClientAutoConfiguration.class,
MongoAutoConfiguration.class,
MistralAiAutoConfiguration.class,
MongoDataAutoConfiguration.class,
org.springframework.ai.autoconfigure.vectorstore.mongo.MongoDBAtlasVectorStoreAutoConfiguration.class,
org.springframework.ai.vectorstore.mongodb.autoconfigure.MongoDBAtlasVectorStoreAutoConfiguration.class,
OpenAiAudioSpeechAutoConfiguration.class,
OpenAiAutoConfiguration.class,
OpenAiAudioTranscriptionAutoConfiguration.class,
OpenAiChatAutoConfiguration.class,
OpenAiEmbeddingAutoConfiguration.class,
OpenAiImageAutoConfiguration.class,
OpenAiModerationAutoConfiguration.class})
class McpServerApplication {

public static void main(String[] args) {
SpringApplication app = new SpringApplication(McpServerApplication.class);
app.setAdditionalProfiles("mcp");
app.run(args);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.baeldung.springai.mcp.oauth2;

import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;

public class StockInformationHolder {
@Tool(description = "Get stock price for a company symbol")
public String getStockPrice(@ToolParam String symbol) {
if ("AAPL".equalsIgnoreCase(symbol)) {
return "AAPL: $150.00";
} else if ("GOOGL".equalsIgnoreCase(symbol)) {
return "GOOGL: $2800.00";
} else {
return symbol + ": Data not available";
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.baeldung.springai.mcp.oauth2.configuration;

import com.baeldung.springai.mcp.oauth2.StockInformationHolder;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.method.MethodToolCallbackProvider;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Profile;

@Profile("mcp")
@Configuration
public class McpServerConfiguration {

@Bean
public ToolCallbackProvider stockTools() {
return MethodToolCallbackProvider
.builder()
.toolObjects(new StockInformationHolder())
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.baeldung.springai.mcp.oauth2.configuration;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
import org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers.OAuth2AuthorizationServerConfigurer;
import org.springframework.security.web.SecurityFilterChain;

@Configuration
@EnableWebSecurity
public class McpServerSecurityConfiguration {
@Bean
public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
return http
.authorizeHttpRequests(auth -> auth
.requestMatchers("/mcp/**").authenticated()
.requestMatchers("/sse").authenticated()
.anyRequest().permitAll())
.with(OAuth2AuthorizationServerConfigurer.authorizationServer(), Customizer.withDefaults())
.oauth2ResourceServer(oauth2 -> oauth2.jwt(Customizer.withDefaults()))
.csrf(CsrfConfigurer::disable)
.cors(Customizer.withDefaults())
.build();
}
}
19 changes: 19 additions & 0 deletions spring-ai-3/src/main/resources/application-mcp.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
spring:
security:
oauth2:
authorizationserver:
client:
oidc-client:
registration:
client-id: mcp-client
client-secret: "{noop}secret"
client-authentication-methods: client_secret_basic
authorization-grant-types: client_credentials
# Avoid starting docker from the shared codebase
docker:
compose:
enabled: false

logging:
level:
org.springframework.ai.mcp: DEBUG
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package com.baeldung.springai.mcp.oauth2;

import com.fasterxml.jackson.databind.JsonNode;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;

import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Base64;

import static org.assertj.core.api.Assertions.assertThat;

@ActiveProfiles("mcp")
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
class McpServerOAuth2LiveTest {

private static final Logger log = LoggerFactory.getLogger(McpServerOAuth2LiveTest.class);

@LocalServerPort
private int port;

private WebClient webClient;

@BeforeEach
void setup() {
webClient = WebClient.create("http://localhost:" + port);
}

@Test
void givenSecuredMcpServer_whenCallingTheEndpointsWithValidAuthorizationHeader_thenExpectedResponseShouldBeObtained() {
Flux<String> eventStream = webClient.get()
.uri("/sse")
.header("Authorization", obtainAccessToken())
.accept(MediaType.TEXT_EVENT_STREAM)
.retrieve()
.bodyToFlux(String.class);

eventStream.subscribe(
data -> {
log.info("Response received: {}", data);
if(!isRequestMessage(data)) {
assertThat(data).containsSequence("AAPL", "$150");
}
},
error -> log.error(error.getMessage()),
() -> log.info("Stream completed"));

Flux<String> sendMessage = webClient.post()
.uri("/mcp/message")
.header("Authorization", obtainAccessToken())
.contentType(MediaType.APPLICATION_JSON)
.accept(MediaType.TEXT_EVENT_STREAM)
.bodyValue("""
{
"jsonrpc": "2.0",
"id": "1",
"method": "tools/call",
"params": {
"name": "getStockPrice",
"arguments": {
"arg0": "AAPL"
}
}
}
""")
.retrieve()
.bodyToFlux(String.class);

sendMessage.blockLast();
eventStream.blockLast();
}

private boolean isRequestMessage(String data) {
return data.contains("/mcp/message");
}

public String obtainAccessToken() {
String clientId = "mcp-client";
String clientSecret = "secret";
String basicToken = Base64.getEncoder()
.encodeToString((clientId + ":" + clientSecret).getBytes(StandardCharsets.UTF_8));

return "Bearer " + webClient.post()
.uri("/oauth2/token")
.header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE)
.header(HttpHeaders.AUTHORIZATION, "Basic " + basicToken)
.body(BodyInserters
.fromFormData("grant_type", "client_credentials")
)
.retrieve()
.bodyToMono(JsonNode.class)
.map(node -> node.get("access_token").asText())
.block(Duration.ofSeconds(5));
}
}