diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 13c9a1b3c07..0bdd2bfe340 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,8 @@ import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.annotation.AnnotationBeanNameGenerator; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -160,7 +162,7 @@ private OAuth2AuthorizedClientManager getAuthorizedClientManager() { * @since 6.2.0 */ static final class OAuth2AuthorizedClientManagerRegistrar - implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware { + implements ApplicationEventPublisherAware, BeanDefinitionRegistryPostProcessor, BeanFactoryAware { static final String BEAN_NAME = "authorizedClientManagerRegistrar"; @@ -179,6 +181,8 @@ static final class OAuth2AuthorizedClientManagerRegistrar private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator(); + private ApplicationEventPublisher applicationEventPublisher; + private ListableBeanFactory beanFactory; @Override @@ -302,6 +306,10 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider( authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); } + if (this.applicationEventPublisher != null) { + authorizedClientProvider.setApplicationEventPublisher(this.applicationEventPublisher); + } + return authorizedClientProvider; } @@ -423,6 +431,11 @@ private T getBeanOfType(ResolvableType resolvableType) { return objectProvider.getIfAvailable(); } + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { + this.applicationEventPublisher = applicationEventPublisher; + } + } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index ad93db75d3d..dcacd52df18 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -57,6 +57,7 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider; +import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizedClientRefreshedEventListener; import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry; import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation; import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry; @@ -90,6 +91,7 @@ import org.springframework.security.web.authentication.session.SessionAuthenticationException; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.util.matcher.AndRequestMatcher; @@ -386,14 +388,26 @@ public void init(B http) throws Exception { OAuth2UserService oidcUserService = getOidcUserService(); OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider = new OidcAuthorizationCodeAuthenticationProvider( accessTokenResponseClient, oidcUserService); + OidcAuthorizedClientRefreshedEventListener oidcAuthorizedClientRefreshedEventListener = new OidcAuthorizedClientRefreshedEventListener(); + oidcAuthorizedClientRefreshedEventListener.setUserService(oidcUserService); + oidcAuthorizedClientRefreshedEventListener + .setApplicationEventPublisher(http.getSharedObject(ApplicationContext.class)); + JwtDecoderFactory jwtDecoderFactory = this.getJwtDecoderFactoryBean(); if (jwtDecoderFactory != null) { oidcAuthorizationCodeAuthenticationProvider.setJwtDecoderFactory(jwtDecoderFactory); + oidcAuthorizedClientRefreshedEventListener.setJwtDecoderFactory(jwtDecoderFactory); } if (userAuthoritiesMapper != null) { oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper); + oidcAuthorizedClientRefreshedEventListener.setAuthoritiesMapper(userAuthoritiesMapper); } - http.authenticationProvider(this.postProcess(oidcAuthorizationCodeAuthenticationProvider)); + oidcAuthorizationCodeAuthenticationProvider = this.postProcess(oidcAuthorizationCodeAuthenticationProvider); + http.authenticationProvider(oidcAuthorizationCodeAuthenticationProvider); + + oidcAuthorizedClientRefreshedEventListener = this.postProcess(oidcAuthorizedClientRefreshedEventListener); + registerDelegateApplicationListener(oidcAuthorizedClientRefreshedEventListener); + configureOidcUserRefreshedEventListener(http); } else { http.authenticationProvider(new OidcAuthenticationRequestChecker()); @@ -621,6 +635,16 @@ private void configureOidcSessionRegistry(B http) { registerDelegateApplicationListener(listener); } + private void configureOidcUserRefreshedEventListener(B http) { + OidcUserRefreshedEventListener oidcUserRefreshedEventListener = new OidcUserRefreshedEventListener(); + oidcUserRefreshedEventListener.setSecurityContextHolderStrategy(this.getSecurityContextHolderStrategy()); + SecurityContextRepository securityContextRepository = http.getSharedObject(SecurityContextRepository.class); + if (securityContextRepository != null) { + oidcUserRefreshedEventListener.setSecurityContextRepository(securityContextRepository); + } + registerDelegateApplicationListener(oidcUserRefreshedEventListener); + } + private void registerDelegateApplicationListener(ApplicationListener delegate) { DelegatingApplicationListener delegating = getBeanOrNull( ResolvableType.forType(DelegatingApplicationListener.class)); diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListener.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListener.java new file mode 100644 index 00000000000..e9dce35b69e --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListener.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.configurers.oauth2.client; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.springframework.context.ApplicationListener; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.oauth2.client.oidc.authentication.event.OidcUserRefreshedEvent; +import org.springframework.security.web.context.HttpSessionSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; +import org.springframework.util.Assert; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; + +/** + * An {@link ApplicationListener} that listens for events of type + * {@link OidcUserRefreshedEvent} and refreshes the {@link SecurityContext}. + * + * @author Steve Riesenberg + * @since 6.5 + * @see org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider + */ +final class OidcUserRefreshedEventListener implements ApplicationListener { + + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + + private SecurityContextRepository securityContextRepository = new HttpSessionSecurityContextRepository(); + + @Override + public void onApplicationEvent(OidcUserRefreshedEvent event) { + SecurityContext securityContext = this.securityContextHolderStrategy.createEmptyContext(); + securityContext.setAuthentication(event.getAuthentication()); + this.securityContextHolderStrategy.setContext(securityContext); + + RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes(); + if (!(requestAttributes instanceof ServletRequestAttributes servletRequestAttributes)) { + return; + } + + HttpServletRequest request = servletRequestAttributes.getRequest(); + HttpServletResponse response = servletRequestAttributes.getResponse(); + this.securityContextRepository.saveContext(securityContext, request, response); + } + + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * @param securityContextHolderStrategy the {@link SecurityContextHolderStrategy} to + * use + */ + void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} upon + * receiving an {@link OidcUserRefreshedEvent}. + * @param securityContextRepository the {@link SecurityContextRepository} to use + */ + void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + +} diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java b/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java index 669d6f7f67f..dc131adcd45 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.annotation.AnnotationBeanNameGenerator; import org.springframework.core.ResolvableType; import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; @@ -197,6 +198,12 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider( authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); } + ApplicationEventPublisher applicationEventPublisher = getBeanOfType( + ResolvableType.forClass(ApplicationEventPublisher.class)); + if (applicationEventPublisher != null) { + authorizedClientProvider.setApplicationEventPublisher(applicationEventPublisher); + } + return authorizedClientProvider; } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java new file mode 100644 index 00000000000..5ad7f9726e2 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java @@ -0,0 +1,504 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.configurers.oauth2.client; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import jakarta.servlet.http.HttpServletRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.config.Customizer; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; +import org.springframework.security.config.test.SpringTestContext; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextImpl; +import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.core.user.DefaultOAuth2User; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtDecoderFactory; +import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.context.SecurityContextRepository; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Tests for {@link OidcUserRefreshedEventListener} with {@link OAuth2LoginConfigurer}. + * + * @author Steve Riesenberg + */ +public class OidcUserRefreshedEventListenerConfigurationTests { + + // @formatter:off + private static final ClientRegistration GOOGLE_CLIENT_REGISTRATION = CommonOAuth2Provider.GOOGLE + .getBuilder("google") + .clientId("clientId") + .clientSecret("clientSecret") + .build(); + // @formatter:on + + // @formatter:off + private static final ClientRegistration GITHUB_CLIENT_REGISTRATION = CommonOAuth2Provider.GITHUB + .getBuilder("github") + .clientId("clientId") + .clientSecret("clientSecret") + .build(); + // @formatter:on + + private static final String SUBJECT = "surfer-dude"; + + private static final String ACCESS_TOKEN_VALUE = "hang-ten"; + + private static final String REFRESH_TOKEN_VALUE = "surfs-up"; + + private static final String ID_TOKEN_VALUE = "beach-break"; + + public final SpringTestContext spring = new SpringTestContext(this); + + @Autowired + private SecurityContextRepository securityContextRepository; + + @Autowired + private OAuth2AuthorizedClientRepository authorizedClientRepository; + + @Autowired + private OAuth2AccessTokenResponseClient refreshTokenAccessTokenResponseClient; + + @Autowired + private JwtDecoder jwtDecoder; + + @Autowired + private OidcUserService oidcUserService; + + @Autowired + private OAuth2AuthorizedClientManager authorizedClientManager; + + private MockHttpServletRequest request; + + private MockHttpServletResponse response; + + @BeforeEach + public void setUp() { + this.request = new MockHttpServletRequest("GET", ""); + this.request.setServletPath("/"); + this.response = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.request, this.response)); + } + + @AfterEach + public void cleanUp() { + SecurityContextHolder.clearContext(); + RequestContextHolder.resetRequestAttributes(); + } + + @Test + public void authorizeWhenAccessTokenResponseMissingOpenidScopeThenOidcUserNotRefreshed() { + this.spring.register(OAuth2LoginWithOAuth2ClientConfig.class).autowire(); + + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = createAccessTokenResponse(); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(HttpServletRequest.class))) + .willReturn(authorizedClient); + given(this.refreshTokenAccessTokenResponseClient.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + + OAuth2AuthenticationToken authentication = createAuthenticationToken(GOOGLE_CLIENT_REGISTRATION); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(GOOGLE_CLIENT_REGISTRATION.getRegistrationId()) + .principal(authentication) + .build(); + OAuth2AuthorizedClient refreshedAuthorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(refreshedAuthorizedClient).isNotNull(); + verifyNoInteractions(this.securityContextRepository, this.jwtDecoder, this.oidcUserService); + } + + @Test + public void authorizeWhenAccessTokenResponseMissingIdTokenThenOidcUserNotRefreshed() { + this.spring.register(OAuth2LoginWithOAuth2ClientConfig.class).autowire(); + + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.oidcAccessTokenResponse() + .build(); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(HttpServletRequest.class))) + .willReturn(authorizedClient); + given(this.refreshTokenAccessTokenResponseClient.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + + OAuth2AuthenticationToken authentication = createAuthenticationToken(GOOGLE_CLIENT_REGISTRATION); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(GOOGLE_CLIENT_REGISTRATION.getRegistrationId()) + .principal(authentication) + .build(); + OAuth2AuthorizedClient refreshedAuthorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(refreshedAuthorizedClient).isNotNull(); + verifyNoInteractions(this.securityContextRepository, this.jwtDecoder, this.oidcUserService); + } + + @Test + public void authorizeWhenAuthenticationIsNotOAuth2ThenOidcUserNotRefreshed() { + this.spring.register(OAuth2LoginWithOAuth2ClientConfig.class).autowire(); + + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = createAccessTokenResponse(OidcScopes.OPENID); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(HttpServletRequest.class))) + .willReturn(authorizedClient); + given(this.refreshTokenAccessTokenResponseClient.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken(SUBJECT, null); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + SecurityContextHolder.setContext(securityContext); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(GOOGLE_CLIENT_REGISTRATION.getRegistrationId()) + .principal(authentication) + .build(); + OAuth2AuthorizedClient refreshedAuthorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(refreshedAuthorizedClient).isNotNull(); + verifyNoInteractions(this.securityContextRepository, this.jwtDecoder, this.oidcUserService); + } + + @Test + public void authorizeWhenAuthenticationIsCustomThenOidcUserNotRefreshed() { + this.spring.register(OAuth2LoginWithOAuth2ClientConfig.class).autowire(); + + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = createAccessTokenResponse(OidcScopes.OPENID); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(HttpServletRequest.class))) + .willReturn(authorizedClient); + given(this.refreshTokenAccessTokenResponseClient.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + + OidcUser oidcUser = createOidcUser(); + OAuth2AuthenticationToken authentication = new CustomOAuth2AuthenticationToken(oidcUser, + oidcUser.getAuthorities(), GOOGLE_CLIENT_REGISTRATION.getRegistrationId()); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + SecurityContextHolder.setContext(securityContext); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(GOOGLE_CLIENT_REGISTRATION.getRegistrationId()) + .principal(authentication) + .build(); + OAuth2AuthorizedClient refreshedAuthorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(refreshedAuthorizedClient).isNotNull(); + verifyNoInteractions(this.securityContextRepository, this.jwtDecoder, this.oidcUserService); + } + + @Test + public void authorizeWhenPrincipalIsOAuth2UserThenOidcUserNotRefreshed() { + this.spring.register(OAuth2LoginWithOAuth2ClientConfig.class).autowire(); + + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = createAccessTokenResponse(OidcScopes.OPENID); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(HttpServletRequest.class))) + .willReturn(authorizedClient); + given(this.refreshTokenAccessTokenResponseClient.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + + Map attributes = Map.of(StandardClaimNames.SUB, SUBJECT); + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("OAUTH2_USER"), attributes, + StandardClaimNames.SUB); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(oauth2User, + oauth2User.getAuthorities(), GOOGLE_CLIENT_REGISTRATION.getRegistrationId()); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + SecurityContextHolder.setContext(securityContext); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(GOOGLE_CLIENT_REGISTRATION.getRegistrationId()) + .principal(authentication) + .build(); + OAuth2AuthorizedClient refreshedAuthorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(refreshedAuthorizedClient).isNotNull(); + verifyNoInteractions(this.securityContextRepository, this.jwtDecoder, this.oidcUserService); + } + + @Test + public void authorizeWhenAuthenticationClientRegistrationIdDoesNotMatchThenOidcUserNotRefreshed() { + this.spring.register(OAuth2LoginWithOAuth2ClientConfig.class).autowire(); + + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = createAccessTokenResponse(OidcScopes.OPENID); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(HttpServletRequest.class))) + .willReturn(authorizedClient); + given(this.refreshTokenAccessTokenResponseClient.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + + OAuth2AuthenticationToken authentication = createAuthenticationToken(GITHUB_CLIENT_REGISTRATION); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + SecurityContextHolder.setContext(securityContext); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(GOOGLE_CLIENT_REGISTRATION.getRegistrationId()) + .principal(authentication) + .build(); + OAuth2AuthorizedClient refreshedAuthorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(refreshedAuthorizedClient).isNotNull(); + verifyNoInteractions(this.securityContextRepository, this.jwtDecoder, this.oidcUserService); + } + + @Test + public void authorizeWhenAccessTokenResponseIncludesIdTokenThenOidcUserRefreshed() { + this.spring.register(OAuth2LoginWithOAuth2ClientConfig.class).autowire(); + + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = createAccessTokenResponse(OidcScopes.OPENID); + Jwt jwt = createJwt(); + OidcUser oidcUser = createOidcUser(); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(HttpServletRequest.class))) + .willReturn(authorizedClient); + given(this.refreshTokenAccessTokenResponseClient.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + given(this.jwtDecoder.decode(anyString())).willReturn(jwt); + given(this.oidcUserService.loadUser(any(OidcUserRequest.class))).willReturn(oidcUser); + + OAuth2AuthenticationToken authentication = createAuthenticationToken(GOOGLE_CLIENT_REGISTRATION); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + SecurityContextHolder.setContext(securityContext); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(GOOGLE_CLIENT_REGISTRATION.getRegistrationId()) + .principal(authentication) + .build(); + OAuth2AuthorizedClient refreshedAuthorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(refreshedAuthorizedClient).isNotNull(); + assertThat(refreshedAuthorizedClient).isNotSameAs(authorizedClient); + assertThat(refreshedAuthorizedClient.getClientRegistration()).isEqualTo(GOOGLE_CLIENT_REGISTRATION); + assertThat(refreshedAuthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(refreshedAuthorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); + + ArgumentCaptor refreshTokenGrantRequestCaptor = ArgumentCaptor + .forClass(OAuth2RefreshTokenGrantRequest.class); + ArgumentCaptor userRequestCaptor = ArgumentCaptor.forClass(OidcUserRequest.class); + ArgumentCaptor securityContextCaptor = ArgumentCaptor.forClass(SecurityContext.class); + verify(this.authorizedClientRepository).loadAuthorizedClient(GOOGLE_CLIENT_REGISTRATION.getRegistrationId(), + authentication, this.request); + verify(this.authorizedClientRepository).saveAuthorizedClient(refreshedAuthorizedClient, authentication, + this.request, this.response); + verify(this.refreshTokenAccessTokenResponseClient).getTokenResponse(refreshTokenGrantRequestCaptor.capture()); + verify(this.jwtDecoder).decode(jwt.getTokenValue()); + verify(this.oidcUserService).loadUser(userRequestCaptor.capture()); + verify(this.securityContextRepository).saveContext(securityContextCaptor.capture(), eq(this.request), + eq(this.response)); + verifyNoMoreInteractions(this.authorizedClientRepository, this.jwtDecoder, this.oidcUserService, + this.securityContextRepository); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = refreshTokenGrantRequestCaptor.getValue(); + assertThat(refreshTokenGrantRequest.getClientRegistration()) + .isEqualTo(authorizedClient.getClientRegistration()); + assertThat(refreshTokenGrantRequest.getRefreshToken()).isEqualTo(authorizedClient.getRefreshToken()); + assertThat(refreshTokenGrantRequest.getAccessToken()).isEqualTo(authorizedClient.getAccessToken()); + + OidcUserRequest userRequest = userRequestCaptor.getValue(); + assertThat(userRequest.getClientRegistration()).isEqualTo(GOOGLE_CLIENT_REGISTRATION); + assertThat(userRequest.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(userRequest.getIdToken().getTokenValue()).isEqualTo(jwt.getTokenValue()); + + SecurityContext refreshedSecurityContext = securityContextCaptor.getValue(); + assertThat(refreshedSecurityContext).isNotNull(); + assertThat(refreshedSecurityContext).isNotSameAs(securityContext); + assertThat(refreshedSecurityContext).isSameAs(SecurityContextHolder.getContext()); + assertThat(refreshedSecurityContext.getAuthentication()).isInstanceOf(OAuth2AuthenticationToken.class); + assertThat(refreshedSecurityContext.getAuthentication()).isNotSameAs(authentication); + assertThat(refreshedSecurityContext.getAuthentication().getPrincipal()).isInstanceOf(OidcUser.class); + assertThat(refreshedSecurityContext.getAuthentication().getPrincipal()) + .isNotSameAs(authentication.getPrincipal()); + } + + private OAuth2AuthorizedClient createAuthorizedClient() { + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(30, ChronoUnit.SECONDS); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, ACCESS_TOKEN_VALUE, + issuedAt, expiresAt, Set.of(OidcScopes.OPENID)); + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(REFRESH_TOKEN_VALUE, issuedAt); + + return new OAuth2AuthorizedClient(GOOGLE_CLIENT_REGISTRATION, SUBJECT, accessToken, refreshToken); + } + + private OAuth2AccessTokenResponse createAccessTokenResponse(String... scope) { + Set scopes = Set.of(scope); + Map additionalParameters = new HashMap<>(); + if (scopes.contains(OidcScopes.OPENID)) { + additionalParameters.put(OidcParameterNames.ID_TOKEN, ID_TOKEN_VALUE); + } + + return OAuth2AccessTokenResponse.withToken(ACCESS_TOKEN_VALUE) + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .scopes(scopes) + .refreshToken(REFRESH_TOKEN_VALUE) + .expiresIn(60L) + .additionalParameters(additionalParameters) + .build(); + } + + private Jwt createJwt() { + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.MINUTES); + return TestJwts.jwt() + .subject(SUBJECT) + .tokenValue(ID_TOKEN_VALUE) + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .build(); + } + + private OidcUser createOidcUser() { + Map claims = new HashMap<>(); + claims.put(IdTokenClaimNames.SUB, SUBJECT); + claims.put(IdTokenClaimNames.ISS, "issuer"); + claims.put(IdTokenClaimNames.AUD, List.of("audience1", "audience2")); + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.MINUTES); + OidcIdToken idToken = new OidcIdToken(ID_TOKEN_VALUE, issuedAt, expiresAt, claims); + + return new DefaultOidcUser(AuthorityUtils.createAuthorityList("OIDC_USER"), idToken); + } + + private OAuth2AuthenticationToken createAuthenticationToken(ClientRegistration clientRegistration) { + OidcUser oidcUser = createOidcUser(); + return new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(), + clientRegistration.getRegistrationId()); + } + + @Configuration + @EnableWebSecurity + static class OAuth2LoginWithOAuth2ClientConfig { + + @Bean + SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeHttpRequests((authorize) -> authorize + .anyRequest().authenticated() + ) + .securityContext((securityContext) -> securityContext + .securityContextRepository(this.securityContextRepository()) + ) + .oauth2Login(Customizer.withDefaults()) + .oauth2Client(Customizer.withDefaults()); + // @formatter:on + return http.build(); + } + + @Bean + SecurityContextRepository securityContextRepository() { + return mock(SecurityContextRepository.class); + } + + @Bean + ClientRegistrationRepository clientRegistrationRepository() { + return mock(ClientRegistrationRepository.class); + } + + @Bean + OAuth2AuthorizedClientRepository authorizedClientRepository() { + return mock(OAuth2AuthorizedClientRepository.class); + } + + @Bean + @SuppressWarnings("unchecked") + OAuth2AccessTokenResponseClient refreshTokenAccessTokenResponseClient() { + return mock(OAuth2AccessTokenResponseClient.class); + } + + @Bean + JwtDecoder jwtDecoder() { + return mock(JwtDecoder.class); + } + + @Bean + JwtDecoderFactory jwtDecoderFactory() { + return (clientRegistration) -> jwtDecoder(); + } + + @Bean + OidcUserService oidcUserService() { + return mock(OidcUserService.class); + } + + } + + private static final class CustomOAuth2AuthenticationToken extends OAuth2AuthenticationToken { + + CustomOAuth2AuthenticationToken(OAuth2User principal, Collection authorities, + String authorizedClientRegistrationId) { + super(principal, authorities, authorizedClientRegistrationId); + } + + } + +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java new file mode 100644 index 00000000000..6b8f82a8bd0 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java @@ -0,0 +1,135 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.configurers.oauth2.client; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.oidc.authentication.event.OidcUserRefreshedEvent; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers; +import org.springframework.security.web.context.SecurityContextRepository; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Tests for {@link OidcUserRefreshedEventListener}. + * + * @author Steve Riesenberg + */ +public class OidcUserRefreshedEventListenerTests { + + private OidcUserRefreshedEventListener eventListener; + + private SecurityContextRepository securityContextRepository; + + private MockHttpServletRequest request; + + private MockHttpServletResponse response; + + @BeforeEach + public void setUp() { + this.securityContextRepository = mock(SecurityContextRepository.class); + this.eventListener = new OidcUserRefreshedEventListener(); + this.eventListener.setSecurityContextRepository(this.securityContextRepository); + + this.request = new MockHttpServletRequest("GET", ""); + this.request.setServletPath("/"); + this.response = new MockHttpServletResponse(); + } + + @AfterEach + public void cleanUp() { + SecurityContextHolder.clearContext(); + RequestContextHolder.resetRequestAttributes(); + } + + @Test + public void setSecurityContextHolderStrategyWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.eventListener.setSecurityContextHolderStrategy(null)) + .withMessage("securityContextHolderStrategy cannot be null"); + } + + @Test + public void setSecurityContextRepositoryWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.eventListener.setSecurityContextRepository(null)) + .withMessage("securityContextRepository cannot be null"); + } + + @Test + public void onApplicationEventWhenRequestAttributesSetThenSecurityContextSaved() { + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.request, this.response)); + + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.oidcAccessTokenResponse() + .build(); + OidcUser oldOidcUser = TestOidcUsers.create(); + OidcUser newOidcUser = TestOidcUsers.create(); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(newOidcUser, + newOidcUser.getAuthorities(), "test"); + OidcUserRefreshedEvent event = new OidcUserRefreshedEvent(accessTokenResponse, oldOidcUser, newOidcUser, + authentication); + this.eventListener.onApplicationEvent(event); + + ArgumentCaptor securityContextCaptor = ArgumentCaptor.forClass(SecurityContext.class); + verify(this.securityContextRepository).saveContext(securityContextCaptor.capture(), eq(this.request), + eq(this.response)); + verifyNoMoreInteractions(this.securityContextRepository); + + SecurityContext securityContext = securityContextCaptor.getValue(); + assertThat(securityContext).isNotNull(); + assertThat(securityContext).isSameAs(SecurityContextHolder.getContext()); + assertThat(securityContext.getAuthentication()).isSameAs(authentication); + } + + @Test + public void onApplicationEventWhenRequestAttributesNotSetThenSecurityContextNotSaved() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.oidcAccessTokenResponse() + .build(); + OidcUser oldOidcUser = TestOidcUsers.create(); + OidcUser newOidcUser = TestOidcUsers.create(); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(newOidcUser, + newOidcUser.getAuthorities(), "test"); + OidcUserRefreshedEvent event = new OidcUserRefreshedEvent(accessTokenResponse, oldOidcUser, newOidcUser, + authentication); + OidcUserRefreshedEventListener eventListener = new OidcUserRefreshedEventListener(); + eventListener.setSecurityContextRepository(this.securityContextRepository); + eventListener.onApplicationEvent(event); + verifyNoInteractions(this.securityContextRepository); + + SecurityContext securityContext = SecurityContextHolder.getContext(); + assertThat(securityContext).isNotNull(); + assertThat(securityContext.getAuthentication()).isSameAs(authentication); + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java index c0c8bee93ee..316b2d5c131 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import java.util.Map; import java.util.function.Consumer; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; @@ -359,6 +360,8 @@ public final class RefreshTokenGrantBuilder implements Builder { private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private ApplicationEventPublisher eventPublisher; + private Duration clockSkew; private Clock clock; @@ -379,6 +382,18 @@ public RefreshTokenGrantBuilder accessTokenResponseClient( return this; } + /** + * Sets the {@link ApplicationEventPublisher} used when an access token is + * refreshed. + * @param eventPublisher the {@link ApplicationEventPublisher} + * @return the {@link RefreshTokenGrantBuilder} + * @since 6.5 + */ + public RefreshTokenGrantBuilder eventPublisher(ApplicationEventPublisher eventPublisher) { + this.eventPublisher = eventPublisher; + return this; + } + /** * Sets the maximum acceptable clock skew, which is used when checking the access * token expiry. An access token is considered expired if @@ -414,6 +429,9 @@ public OAuth2AuthorizedClientProvider build() { if (this.accessTokenResponseClient != null) { authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); } + if (this.eventPublisher != null) { + authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher); + } if (this.clockSkew != null) { authorizedClientProvider.setClockSkew(this.clockSkew); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index 410a33fda18..e306c7e9fe6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,10 +24,13 @@ import java.util.HashSet; import java.util.Set; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.lang.Nullable; import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.event.OAuth2AuthorizedClientRefreshedEvent; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Token; @@ -43,10 +46,13 @@ * @see OAuth2AuthorizedClientProvider * @see DefaultRefreshTokenTokenResponseClient */ -public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { +public final class RefreshTokenOAuth2AuthorizedClientProvider + implements OAuth2AuthorizedClientProvider, ApplicationEventPublisherAware { private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + private ApplicationEventPublisher applicationEventPublisher; + private Duration clockSkew = Duration.ofSeconds(60); private Clock clock = Clock.systemUTC(); @@ -91,8 +97,18 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { authorizedClient.getClientRegistration(), authorizedClient.getAccessToken(), authorizedClient.getRefreshToken(), scopes); OAuth2AccessTokenResponse tokenResponse = getTokenResponse(authorizedClient, refreshTokenGrantRequest); - return new OAuth2AuthorizedClient(context.getAuthorizedClient().getClientRegistration(), - context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); + + OAuth2AuthorizedClient refreshedAuthorizedClient = new OAuth2AuthorizedClient( + authorizedClient.getClientRegistration(), context.getPrincipal().getName(), + tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); + + if (this.applicationEventPublisher != null) { + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + tokenResponse, refreshedAuthorizedClient); + this.applicationEventPublisher.publishEvent(authorizedClientRefreshedEvent); + } + + return refreshedAuthorizedClient; } private OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizedClient authorizedClient, @@ -149,4 +165,10 @@ public void setClock(Clock clock) { this.clock = clock; } + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { + Assert.notNull(applicationEventPublisher, "applicationEventPublisher cannot be null"); + this.applicationEventPublisher = applicationEventPublisher; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2AuthorizedClientRefreshedEvent.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2AuthorizedClientRefreshedEvent.java new file mode 100644 index 00000000000..aa70bf76f35 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2AuthorizedClientRefreshedEvent.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.event; + +import java.io.Serial; + +import org.springframework.context.ApplicationEvent; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.util.Assert; + +/** + * An event that is published when an {@link OAuth2AuthorizedClient} is refreshed as a + * result of using a {@code refresh_token} to obtain an OAuth 2.0 Access Token Response. + * + * @author Steve Riesenberg + * @since 6.5 + */ +public final class OAuth2AuthorizedClientRefreshedEvent extends ApplicationEvent { + + @Serial + private static final long serialVersionUID = -2178028089321556476L; + + private final OAuth2AuthorizedClient authorizedClient; + + /** + * Creates a new instance with the provided parameters. + * @param accessTokenResponse the {@link OAuth2AccessTokenResponse} that triggered the + * event + * @param authorizedClient the refreshed {@link OAuth2AuthorizedClient} + */ + public OAuth2AuthorizedClientRefreshedEvent(OAuth2AccessTokenResponse accessTokenResponse, + OAuth2AuthorizedClient authorizedClient) { + super(accessTokenResponse); + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + this.authorizedClient = authorizedClient; + } + + /** + * Returns the {@link OAuth2AccessTokenResponse} that triggered the event. + * @return the access token response + */ + public OAuth2AccessTokenResponse getAccessTokenResponse() { + return (OAuth2AccessTokenResponse) this.getSource(); + } + + /** + * Returns the refreshed {@link OAuth2AuthorizedClient}. + * @return the authorized client + */ + public OAuth2AuthorizedClient getAuthorizedClient() { + return this.authorizedClient; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizedClientRefreshedEventListener.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizedClientRefreshedEventListener.java new file mode 100644 index 00000000000..b6f8c45b499 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizedClientRefreshedEventListener.java @@ -0,0 +1,219 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.authentication; + +import java.util.Collection; +import java.util.Map; + +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEventPublisherAware; +import org.springframework.context.ApplicationListener; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; +import org.springframework.security.oauth2.client.event.OAuth2AuthorizedClientRefreshedEvent; +import org.springframework.security.oauth2.client.oidc.authentication.event.OidcUserRefreshedEvent; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtDecoderFactory; +import org.springframework.security.oauth2.jwt.JwtException; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * An {@link ApplicationListener} that listens for events of type + * {@link OAuth2AuthorizedClientRefreshedEvent} and publishes an event of type + * {@link OidcUserRefreshedEvent} in order to refresh an {@link OidcUser}. + * + * @author Steve Riesenberg + * @since 6.5 + * @see org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider + * @see OAuth2AuthorizedClientRefreshedEvent + * @see OidcUserRefreshedEvent + */ +public final class OidcAuthorizedClientRefreshedEventListener + implements ApplicationEventPublisherAware, ApplicationListener { + + private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token"; + + private static final String INVALID_NONCE_ERROR_CODE = "invalid_nonce"; + + private OAuth2UserService userService = new OidcUserService(); + + private JwtDecoderFactory jwtDecoderFactory = new OidcIdTokenDecoderFactory(); + + private GrantedAuthoritiesMapper authoritiesMapper = (authorities) -> authorities; + + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + + private ApplicationEventPublisher applicationEventPublisher; + + @Override + public void onApplicationEvent(OAuth2AuthorizedClientRefreshedEvent event) { + if (this.applicationEventPublisher == null) { + return; + } + + // The response must contain the openid scope + OAuth2AccessTokenResponse accessTokenResponse = event.getAccessTokenResponse(); + if (!accessTokenResponse.getAccessToken().getScopes().contains(OidcScopes.OPENID)) { + return; + } + + // The response must contain an id_token + Map additionalParameters = accessTokenResponse.getAdditionalParameters(); + if (!StringUtils.hasText((String) additionalParameters.get(OidcParameterNames.ID_TOKEN))) { + return; + } + + // The current authentication must be an OAuth2AuthenticationToken + Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication(); + if (!(authentication instanceof OAuth2AuthenticationToken authenticationToken) + || authenticationToken.getClass() != OAuth2AuthenticationToken.class) { + // This event listener only handles the default authentication result. If the + // application customizes the authentication result, then a custom event + // handler should be provided. + return; + } + + // The current principal must be an OidcUser + if (!(authenticationToken.getPrincipal() instanceof OidcUser existingOidcUser)) { + return; + } + + // The registrationId must match the one used to log in + ClientRegistration clientRegistration = event.getAuthorizedClient().getClientRegistration(); + if (!authenticationToken.getAuthorizedClientRegistrationId().equals(clientRegistration.getRegistrationId())) { + return; + } + + // Refresh the OidcUser and send a user refreshed event + OidcIdToken idToken = createOidcToken(clientRegistration, accessTokenResponse); + validateNonce(existingOidcUser, idToken); + OidcUserRequest userRequest = new OidcUserRequest(clientRegistration, accessTokenResponse.getAccessToken(), + idToken, additionalParameters); + OidcUser oidcUser = this.userService.loadUser(userRequest); + Collection mappedAuthorities = this.authoritiesMapper + .mapAuthorities(oidcUser.getAuthorities()); + OAuth2AuthenticationToken authenticationResult = new OAuth2AuthenticationToken(oidcUser, mappedAuthorities, + clientRegistration.getRegistrationId()); + authenticationResult.setDetails(authenticationToken.getDetails()); + OidcUserRefreshedEvent oidcUserRefreshedEvent = new OidcUserRefreshedEvent(accessTokenResponse, + existingOidcUser, oidcUser, authenticationResult); + this.applicationEventPublisher.publishEvent(oidcUserRefreshedEvent); + } + + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + + /** + * Sets the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature + * verification. The factory returns a {@link JwtDecoder} associated to the provided + * {@link ClientRegistration}. + * @param jwtDecoderFactory the {@link JwtDecoderFactory} used for {@link OidcIdToken} + * signature verification + */ + public void setJwtDecoderFactory(JwtDecoderFactory jwtDecoderFactory) { + Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null"); + this.jwtDecoderFactory = jwtDecoderFactory; + } + + /** + * Sets the {@link OAuth2UserService} used for obtaining the user attributes of the + * End-User from the UserInfo Endpoint. + * @param userService the service used for obtaining the user attributes of the + * End-User from the UserInfo Endpoint + */ + public void setUserService(OAuth2UserService userService) { + Assert.notNull(userService, "userService cannot be null"); + this.userService = userService; + } + + /** + * Sets the {@link GrantedAuthoritiesMapper} used for mapping + * {@link OidcUser#getAuthorities()}} to a new set of authorities which will be + * associated to the {@link OAuth2LoginAuthenticationToken}. + * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the + * user's authorities + */ + public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { + Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null"); + this.authoritiesMapper = authoritiesMapper; + } + + /** + * Sets the {@link ApplicationEventPublisher} to be used. + * @param applicationEventPublisher event publisher to be used + */ + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { + Assert.notNull(applicationEventPublisher, "applicationEventPublisher cannot be null"); + this.applicationEventPublisher = applicationEventPublisher; + } + + private OidcIdToken createOidcToken(ClientRegistration clientRegistration, + OAuth2AccessTokenResponse accessTokenResponse) { + JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration); + Jwt jwt = getJwt(accessTokenResponse, jwtDecoder); + return new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims()); + } + + private Jwt getJwt(OAuth2AccessTokenResponse accessTokenResponse, JwtDecoder jwtDecoder) { + try { + Map parameters = accessTokenResponse.getAdditionalParameters(); + return jwtDecoder.decode((String) parameters.get(OidcParameterNames.ID_TOKEN)); + } + catch (JwtException ex) { + OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), null); + throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex); + } + } + + private void validateNonce(OidcUser existingOidcUser, OidcIdToken idToken) { + if (!StringUtils.hasText(idToken.getNonce())) { + return; + } + + if (!idToken.getNonce().equals(existingOidcUser.getNonce())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/event/OidcUserRefreshedEvent.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/event/OidcUserRefreshedEvent.java new file mode 100644 index 00000000000..eda81717396 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/event/OidcUserRefreshedEvent.java @@ -0,0 +1,98 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.authentication.event; + +import java.io.Serial; + +import org.springframework.context.ApplicationEvent; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.util.Assert; + +/** + * An event that is published when an {@link OidcUser} is refreshed as a result of using a + * {@code refresh_token} to obtain an OAuth 2.0 Access Token Response that contains an + * {@code id_token}. + * + * @author Steve Riesenberg + * @since 6.5 + * @see org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider + */ +public final class OidcUserRefreshedEvent extends ApplicationEvent { + + @Serial + private static final long serialVersionUID = 2657442604286019694L; + + private final OidcUser oldOidcUser; + + private final OidcUser newOidcUser; + + private final Authentication authentication; + + /** + * Creates a new instance with the provided parameters. + * @param accessTokenResponse the {@link OAuth2AccessTokenResponse} that triggered the + * event + * @param oldOidcUser the original {@link OidcUser} + * @param newOidcUser the refreshed {@link OidcUser} + * @param authentication the authentication result + */ + public OidcUserRefreshedEvent(OAuth2AccessTokenResponse accessTokenResponse, OidcUser oldOidcUser, + OidcUser newOidcUser, Authentication authentication) { + super(accessTokenResponse); + Assert.notNull(oldOidcUser, "oldOidcUser cannot be null"); + Assert.notNull(newOidcUser, "newOidcUser cannot be null"); + Assert.notNull(authentication, "authentication cannot be null"); + this.oldOidcUser = oldOidcUser; + this.newOidcUser = newOidcUser; + this.authentication = authentication; + } + + /** + * Returns the {@link OAuth2AccessTokenResponse} that triggered the event. + * @return the access token response + */ + public OAuth2AccessTokenResponse getAccessTokenResponse() { + return (OAuth2AccessTokenResponse) this.getSource(); + } + + /** + * Returns the original {@link OidcUser}. + * @return the original user + */ + public OidcUser getOldOidcUser() { + return this.oldOidcUser; + } + + /** + * Returns the refreshed {@link OidcUser}. + * @return the refreshed user + */ + public OidcUser getNewOidcUser() { + return this.newOidcUser; + } + + /** + * Returns the authentication result. + * @return the authentication result + */ + public Authentication getAuthentication() { + return this.authentication; + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java index 86ae003eff2..5b7fc551f36 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,10 +25,12 @@ import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.event.OAuth2AuthorizedClientRefreshedEvent; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -251,4 +253,55 @@ public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllega + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); } + @Test + public void shouldPublishEventWhenTokenRefreshed() { + OAuth2TokenRefreshedAwareEventPublisher eventPublisher = new OAuth2TokenRefreshedAwareEventPublisher(); + this.authorizedClientProvider.setApplicationEventPublisher(eventPublisher); + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses + .accessTokenResponse() + .refreshToken("new-refresh-token") + .build(); + // @formatter:on + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + this.authorizedClientProvider.authorize(authorizationContext); + assertThat(eventPublisher.flag).isTrue(); + } + + @Test + public void shouldNotPublishEventWhenTokenNotRefreshed() { + OAuth2TokenRefreshedAwareEventPublisher eventPublisher = new OAuth2TokenRefreshedAwareEventPublisher(); + this.authorizedClientProvider.setApplicationEventPublisher(eventPublisher); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken()); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + this.authorizedClientProvider.authorize(authorizationContext); + assertThat(eventPublisher.flag).isFalse(); + } + + private static class OAuth2TokenRefreshedAwareEventPublisher implements ApplicationEventPublisher { + + Boolean flag = false; + + @Override + public void publishEvent(Object event) { + if (OAuth2AuthorizedClientRefreshedEvent.class.isAssignableFrom(event.getClass())) { + this.flag = true; + } + } + + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizedClientRefreshedEventListenerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizedClientRefreshedEventListenerTests.java new file mode 100644 index 00000000000..64a94a8f37b --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizedClientRefreshedEventListenerTests.java @@ -0,0 +1,420 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.authentication; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.event.OAuth2AuthorizedClientRefreshedEvent; +import org.springframework.security.oauth2.client.oidc.authentication.event.OidcUserRefreshedEvent; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.core.user.DefaultOAuth2User; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtException; +import org.springframework.security.oauth2.jwt.TestJwts; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Tests for {@link OidcAuthorizedClientRefreshedEventListener}. + * + * @author Steve Riesenberg + */ +public class OidcAuthorizedClientRefreshedEventListenerTests { + + private static final String SUBJECT = "surfer-dude"; + + private static final String ACCESS_TOKEN_VALUE = "hang-ten"; + + private static final String REFRESH_TOKEN_VALUE = "surfs-up"; + + private static final String ID_TOKEN_VALUE = "beach-break"; + + private OidcAuthorizedClientRefreshedEventListener eventListener; + + private SecurityContextHolderStrategy securityContextHolderStrategy; + + private JwtDecoder jwtDecoder; + + private OidcUserService userService; + + private ApplicationEventPublisher applicationEventPublisher; + + private ClientRegistration clientRegistration; + + private OAuth2AuthorizedClient authorizedClient; + + private OAuth2AccessTokenResponse accessTokenResponse; + + private Jwt jwt; + + private OidcUser oidcUser; + + @BeforeEach + public void setUp() { + this.jwtDecoder = mock(JwtDecoder.class); + this.userService = mock(OidcUserService.class); + this.securityContextHolderStrategy = mock(SecurityContextHolderStrategy.class); + this.applicationEventPublisher = mock(ApplicationEventPublisher.class); + + this.eventListener = new OidcAuthorizedClientRefreshedEventListener(); + this.eventListener.setUserService(this.userService); + this.eventListener.setJwtDecoderFactory((clientRegistration) -> this.jwtDecoder); + this.eventListener.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); + this.eventListener.setApplicationEventPublisher(this.applicationEventPublisher); + + this.clientRegistration = TestClientRegistrations.clientRegistration().scope(OidcScopes.OPENID).build(); + this.authorizedClient = createAuthorizedClient(this.clientRegistration); + this.accessTokenResponse = createAccessTokenResponse(OidcScopes.OPENID); + this.jwt = createJwt(); + this.oidcUser = createOidcUser(); + } + + @Test + public void setSecurityContextHolderStrategyWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.eventListener.setSecurityContextHolderStrategy(null)) + .withMessage("securityContextHolderStrategy cannot be null"); + } + + @Test + public void setJwtDecoderFactoryWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.eventListener.setJwtDecoderFactory(null)) + .withMessage("jwtDecoderFactory cannot be null"); + } + + @Test + public void setUserServiceWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.eventListener.setUserService(null)) + .withMessage("userService cannot be null"); + } + + @Test + public void setAuthoritiesMapperWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.eventListener.setAuthoritiesMapper(null)) + .withMessage("authoritiesMapper cannot be null"); + } + + @Test + public void setApplicationEventPublisherWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.eventListener.setApplicationEventPublisher(null)) + .withMessage("applicationEventPublisher cannot be null"); + } + + @Test + public void onApplicationEventWhenAccessTokenResponseMissingIdTokenThenOidcUserRefreshedEventNotPublished() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .scopes(Set.of(OidcScopes.OPENID)) + .build(); + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + verifyNoInteractions(this.securityContextHolderStrategy, this.jwtDecoder, this.userService, + this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenAccessTokenResponseMissingOpenidScopeThenOidcUserRefreshedEventNotPublished() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.oidcAccessTokenResponse() + .scopes(Set.of()) + .build(); + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + verifyNoInteractions(this.securityContextHolderStrategy, this.jwtDecoder, this.userService, + this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenAuthenticationIsNotOAuth2ThenOidcUserRefreshedEventNotPublished() { + TestingAuthenticationToken authentication = new TestingAuthenticationToken(SUBJECT, null); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + + verify(this.securityContextHolderStrategy).getContext(); + verifyNoMoreInteractions(this.securityContextHolderStrategy); + verifyNoInteractions(this.jwtDecoder, this.userService, this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenAuthenticationIsCustomThenOidcUserRefreshedEventNotPublished() { + OAuth2AuthenticationToken authentication = new CustomOAuth2AuthenticationToken(this.oidcUser, + this.oidcUser.getAuthorities(), this.clientRegistration.getRegistrationId()); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + + verify(this.securityContextHolderStrategy).getContext(); + verifyNoMoreInteractions(this.securityContextHolderStrategy); + verifyNoInteractions(this.jwtDecoder, this.userService, this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenPrincipalIsOAuth2UserThenOidcUserRefreshedEventNotPublished() { + Map attributes = Map.of(StandardClaimNames.SUB, SUBJECT); + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("OAUTH2_USER"), attributes, + StandardClaimNames.SUB); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(oauth2User, + oauth2User.getAuthorities(), this.clientRegistration.getRegistrationId()); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + + verify(this.securityContextHolderStrategy).getContext(); + verifyNoMoreInteractions(this.securityContextHolderStrategy); + verifyNoInteractions(this.jwtDecoder, this.userService, this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenClientRegistrationIdDoesNotMatchThenOidcUserRefreshedEventNotPublished() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .registrationId("test") + .build(); + OAuth2AuthenticationToken authentication = createAuthenticationToken(clientRegistration); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + + verify(this.securityContextHolderStrategy).getContext(); + verifyNoMoreInteractions(this.securityContextHolderStrategy); + verifyNoInteractions(this.jwtDecoder, this.userService, this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenAccessTokenResponseIncludesIdTokenThenPublishOidcUserRefreshedEvent() { + OAuth2AuthenticationToken authentication = createAuthenticationToken(this.clientRegistration); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + given(this.jwtDecoder.decode(anyString())).willReturn(this.jwt); + given(this.userService.loadUser(any(OidcUserRequest.class))).willReturn(this.oidcUser); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + + ArgumentCaptor userRequestCaptor = ArgumentCaptor.forClass(OidcUserRequest.class); + ArgumentCaptor userRefreshedEventCaptor = ArgumentCaptor + .forClass(OidcUserRefreshedEvent.class); + verify(this.securityContextHolderStrategy).getContext(); + verify(this.jwtDecoder).decode(this.jwt.getTokenValue()); + verify(this.userService).loadUser(userRequestCaptor.capture()); + verify(this.applicationEventPublisher).publishEvent(userRefreshedEventCaptor.capture()); + verifyNoMoreInteractions(this.securityContextHolderStrategy, this.jwtDecoder, this.userService, + this.applicationEventPublisher); + + OidcUserRequest userRequest = userRequestCaptor.getValue(); + assertThat(userRequest.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(userRequest.getAccessToken()).isSameAs(this.accessTokenResponse.getAccessToken()); + assertThat(userRequest.getIdToken().getTokenValue()).isEqualTo(this.jwt.getTokenValue()); + + OidcUserRefreshedEvent userRefreshedEvent = userRefreshedEventCaptor.getValue(); + assertThat(userRefreshedEvent.getAccessTokenResponse()).isSameAs(this.accessTokenResponse); + assertThat(userRefreshedEvent.getOldOidcUser()).isSameAs(authentication.getPrincipal()); + assertThat(userRefreshedEvent.getNewOidcUser()).isSameAs(this.oidcUser); + assertThat(userRefreshedEvent.getAuthentication()).isNotSameAs(authentication); + assertThat(userRefreshedEvent.getAuthentication()).isInstanceOf(OAuth2AuthenticationToken.class); + + OAuth2AuthenticationToken authenticationResult = (OAuth2AuthenticationToken) userRefreshedEvent + .getAuthentication(); + assertThat(authenticationResult.getPrincipal()).isEqualTo(this.oidcUser); + assertThat(authenticationResult.getAuthorities()).containsExactlyElementsOf(this.oidcUser.getAuthorities()); + assertThat(authenticationResult.getAuthorizedClientRegistrationId()) + .isEqualTo(this.clientRegistration.getRegistrationId()); + } + + @Test + public void onApplicationEventWhenIdTokenNonceDoesNotMatchThenThrowsOAuth2AuthenticationException() { + Jwt jwt = TestJwts.jwt().claim(IdTokenClaimNames.NONCE, "invalid").build(); + OAuth2AuthenticationToken authentication = createAuthenticationToken(this.clientRegistration); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + given(this.jwtDecoder.decode(anyString())).willReturn(jwt); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent)) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo("invalid_nonce"); + verify(this.securityContextHolderStrategy).getContext(); + verify(this.jwtDecoder).decode(this.jwt.getTokenValue()); + verifyNoMoreInteractions(this.securityContextHolderStrategy, this.jwtDecoder); + verifyNoInteractions(this.userService, this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenInvalidIdTokenThenThrowsOAuth2AuthenticationException() { + OAuth2AuthenticationToken authentication = createAuthenticationToken(this.clientRegistration); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + given(this.jwtDecoder.decode(anyString())).willThrow(new JwtException("Invalid token")); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent)) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo("invalid_id_token"); + verify(this.securityContextHolderStrategy).getContext(); + verify(this.jwtDecoder).decode(this.jwt.getTokenValue()); + verifyNoMoreInteractions(this.securityContextHolderStrategy, this.jwtDecoder); + verifyNoInteractions(this.userService, this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenCustomAuthoritiesMapperSetThenUsed() { + OAuth2AuthenticationToken authentication = createAuthenticationToken(this.clientRegistration); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + given(this.jwtDecoder.decode(anyString())).willReturn(this.jwt); + given(this.userService.loadUser(any(OidcUserRequest.class))).willReturn(this.oidcUser); + + GrantedAuthoritiesMapper grantedAuthoritiesMapper = mock(GrantedAuthoritiesMapper.class); + this.eventListener.setAuthoritiesMapper(grantedAuthoritiesMapper); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + + verify(grantedAuthoritiesMapper).mapAuthorities(this.oidcUser.getAuthorities()); + verifyNoMoreInteractions(grantedAuthoritiesMapper); + } + + private static OAuth2AuthorizedClient createAuthorizedClient(ClientRegistration clientRegistration) { + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(30, ChronoUnit.SECONDS); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, ACCESS_TOKEN_VALUE, + issuedAt, expiresAt, clientRegistration.getScopes()); + + return new OAuth2AuthorizedClient(clientRegistration, SUBJECT, accessToken); + } + + private static OAuth2AccessTokenResponse createAccessTokenResponse(String... scope) { + Set scopes = Set.of(scope); + Map additionalParameters = new HashMap<>(); + if (scopes.contains(OidcScopes.OPENID)) { + additionalParameters.put(OidcParameterNames.ID_TOKEN, ID_TOKEN_VALUE); + } + + return OAuth2AccessTokenResponse.withToken(ACCESS_TOKEN_VALUE) + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .scopes(scopes) + .refreshToken(REFRESH_TOKEN_VALUE) + .expiresIn(60L) + .additionalParameters(additionalParameters) + .build(); + } + + private static Jwt createJwt() { + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.MINUTES); + return TestJwts.jwt() + .subject(SUBJECT) + .tokenValue(ID_TOKEN_VALUE) + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .claim(OidcParameterNames.NONCE, "nonce") + .build(); + } + + private static OidcUser createOidcUser() { + Map claims = new HashMap<>(); + claims.put(IdTokenClaimNames.SUB, SUBJECT); + claims.put(IdTokenClaimNames.ISS, "issuer"); + claims.put(IdTokenClaimNames.AUD, List.of("audience1", "audience2")); + claims.put(IdTokenClaimNames.NONCE, "nonce"); + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.MINUTES); + OidcIdToken idToken = new OidcIdToken(ID_TOKEN_VALUE, issuedAt, expiresAt, claims); + + return new DefaultOidcUser(AuthorityUtils.createAuthorityList("OIDC_USER"), idToken); + } + + private static OAuth2AuthenticationToken createAuthenticationToken(ClientRegistration clientRegistration) { + OidcUser oidcUser = createOidcUser(); + return new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(), + clientRegistration.getRegistrationId()); + } + + private static final class CustomOAuth2AuthenticationToken extends OAuth2AuthenticationToken { + + CustomOAuth2AuthenticationToken(OAuth2User principal, Collection authorities, + String authorizedClientRegistrationId) { + super(principal, authorities, authorizedClientRegistrationId); + } + + } + +}