diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index d30aa1763ea..c51ed789850 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -695,6 +695,8 @@ public class OAuth2LoginSpec { private ServerWebExchangeMatcher authenticationMatcher; + private ServerAuthenticationSuccessHandler authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); + /** * Configures the {@link ReactiveAuthenticationManager} to use. The default is * {@link OAuth2AuthorizationCodeReactiveAuthenticationManager} @@ -706,6 +708,18 @@ public OAuth2LoginSpec authenticationManager(ReactiveAuthenticationManager authe return this; } + /** + * The {@link ServerAuthenticationSuccessHandler} used after authentication success. Defaults to + * {@link RedirectServerAuthenticationSuccessHandler} redirecting to "/". + * @param authenticationSuccessHandler the success handler to use + * @return the {@link OAuth2LoginSpec} to continue configuring + */ + public OAuth2LoginSpec authenticationSuccessHandler(ServerAuthenticationSuccessHandler authenticationSuccessHandler) { + Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null"); + this.authenticationSuccessHandler = authenticationSuccessHandler; + return this; + } + /** * Gets the {@link ReactiveAuthenticationManager} to use. First tries an explicitly configured manager, and * defaults to {@link OAuth2AuthorizationCodeReactiveAuthenticationManager} @@ -821,9 +835,8 @@ protected void configure(ServerHttpSecurity http) { AuthenticationWebFilter authenticationFilter = new OAuth2LoginAuthenticationWebFilter(manager, authorizedClientRepository); authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher()); authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository)); - RedirectServerAuthenticationSuccessHandler redirectHandler = new RedirectServerAuthenticationSuccessHandler(); - authenticationFilter.setAuthenticationSuccessHandler(redirectHandler); + authenticationFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler); authenticationFilter.setAuthenticationFailureHandler(new ServerAuthenticationFailureHandler() { @Override public Mono onAuthenticationFailure(WebFilterExchange webFilterExchange, diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java index 6857bd6ce53..44bb47bd3e3 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java @@ -23,7 +23,11 @@ import org.junit.Rule; import org.junit.Test; +import org.mockito.stubbing.Answer; import org.openqa.selenium.WebDriver; +import org.springframework.security.web.server.WebFilterExchange; +import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler; +import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; import reactor.core.publisher.Mono; import org.springframework.beans.factory.annotation.Autowired; @@ -184,6 +188,8 @@ public void oauth2LoginWhenCustomObjectsThenUsed() { this.spring.register(OAuth2LoginWithSingleClientRegistrations.class, OAuth2LoginMockAuthenticationManagerConfig.class).autowire(); + String redirectLocation = "/custom-redirect-location"; + WebTestClient webTestClient = WebTestClientBuilder .bindToWebFilters(this.springSecurity) .build(); @@ -194,6 +200,7 @@ public void oauth2LoginWhenCustomObjectsThenUsed() { ReactiveAuthenticationManager manager = config.manager; ServerWebExchangeMatcher matcher = config.matcher; ServerOAuth2AuthorizationRequestResolver resolver = config.resolver; + ServerAuthenticationSuccessHandler successHandler = config.successHandler; OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success(); OAuth2User user = TestOAuth2Users.create(); @@ -205,16 +212,25 @@ public void oauth2LoginWhenCustomObjectsThenUsed() { when(manager.authenticate(any())).thenReturn(Mono.just(result)); when(matcher.matches(any())).thenReturn(ServerWebExchangeMatcher.MatchResult.match()); when(resolver.resolve(any())).thenReturn(Mono.empty()); + when(successHandler.onAuthenticationSuccess(any(), any())).thenAnswer((Answer>) invocation -> { + WebFilterExchange webFilterExchange = invocation.getArgument(0); + Authentication authentication = invocation.getArgument(1); + + return new RedirectServerAuthenticationSuccessHandler(redirectLocation) + .onAuthenticationSuccess(webFilterExchange, authentication); + }); webTestClient.get() .uri("/login/oauth2/code/github") .exchange() - .expectStatus().is3xxRedirection(); + .expectStatus().is3xxRedirection() + .expectHeader().valueEquals("Location", redirectLocation); verify(converter).convert(any()); verify(manager).authenticate(any()); verify(matcher).matches(any()); verify(resolver).resolve(any()); + verify(successHandler).onAuthenticationSuccess(any(), any()); } @Configuration @@ -227,6 +243,8 @@ static class OAuth2LoginMockAuthenticationManagerConfig { ServerOAuth2AuthorizationRequestResolver resolver = mock(ServerOAuth2AuthorizationRequestResolver.class); + ServerAuthenticationSuccessHandler successHandler = mock(ServerAuthenticationSuccessHandler.class); + @Bean public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { http @@ -237,7 +255,8 @@ public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { .authenticationConverter(authenticationConverter) .authenticationManager(manager) .authenticationMatcher(matcher) - .authorizationRequestResolver(resolver); + .authorizationRequestResolver(resolver) + .authenticationSuccessHandler(successHandler); return http.build(); } } @@ -425,4 +444,5 @@ Mono authentication(Authentication authentication) { T getBean(Class beanClass) { return this.spring.getContext().getBean(beanClass); } + }