From b9cc6e1b538123568a63070ebcea62c240bb90db Mon Sep 17 00:00:00 2001 From: Hao Date: Fri, 14 Feb 2025 00:08:50 +0800 Subject: [PATCH 1/2] Ensure ID Token is updated after refresh token Signed-off-by: Hao --- .../OAuth2ClientConfiguration.java | 16 +- .../oauth2/client/OAuth2LoginConfigurer.java | 10 + ...Auth2AuthorizedClientManagerRegistrar.java | 7 + ...OAuth2AuthorizedClientProviderBuilder.java | 17 ++ ...shTokenOAuth2AuthorizedClientProvider.java | 26 +- .../event/OAuth2TokenRefreshedEvent.java | 47 +++ .../RefreshOidcIdTokenHandler.java | 139 +++++++++ ...enOAuth2AuthorizedClientProviderTests.java | 53 ++++ .../RefreshOidcIdTokenHandlerTests.java | 284 ++++++++++++++++++ 9 files changed, 595 insertions(+), 4 deletions(-) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2TokenRefreshedEvent.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandler.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandlerTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 13c9a1b3c07..55de62810d5 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -34,6 +34,9 @@ import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.annotation.AnnotationBeanNameGenerator; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -160,7 +163,7 @@ private OAuth2AuthorizedClientManager getAuthorizedClientManager() { * @since 6.2.0 */ static final class OAuth2AuthorizedClientManagerRegistrar - implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware { + implements ApplicationContextAware, BeanDefinitionRegistryPostProcessor, BeanFactoryAware { static final String BEAN_NAME = "authorizedClientManagerRegistrar"; @@ -179,6 +182,8 @@ static final class OAuth2AuthorizedClientManagerRegistrar private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator(); + private ApplicationEventPublisher eventPublisher; + private ListableBeanFactory beanFactory; @Override @@ -302,6 +307,10 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider( authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); } + if (this.eventPublisher != null) { + authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher); + } + return authorizedClientProvider; } @@ -423,6 +432,11 @@ private T getBeanOfType(ResolvableType resolvableType) { return objectProvider.getIfAvailable(); } + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.eventPublisher = applicationContext; + } + } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index ad93db75d3d..155918dddb0 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -57,6 +57,7 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider; +import org.springframework.security.oauth2.client.oidc.authentication.RefreshOidcIdTokenHandler; import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry; import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation; import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry; @@ -394,6 +395,15 @@ public void init(B http) throws Exception { oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper); } http.authenticationProvider(this.postProcess(oidcAuthorizationCodeAuthenticationProvider)); + + RefreshOidcIdTokenHandler refreshOidcIdTokenHandler = new RefreshOidcIdTokenHandler(); + if (this.getSecurityContextHolderStrategy() != null) { + refreshOidcIdTokenHandler.setSecurityContextHolderStrategy(this.getSecurityContextHolderStrategy()); + } + if (jwtDecoderFactory != null) { + refreshOidcIdTokenHandler.setJwtDecoderFactory(jwtDecoderFactory); + } + registerDelegateApplicationListener(refreshOidcIdTokenHandler); } else { http.authenticationProvider(new OidcAuthenticationRequestChecker()); diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java b/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java index 669d6f7f67f..d2252435f7b 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java @@ -34,6 +34,7 @@ import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.annotation.AnnotationBeanNameGenerator; import org.springframework.core.ResolvableType; import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; @@ -197,6 +198,12 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider( authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); } + ApplicationEventPublisher applicationEventPublisher = getBeanOfType( + ResolvableType.forClass(ApplicationEventPublisher.class)); + if (applicationEventPublisher != null) { + authorizedClientProvider.setApplicationEventPublisher(applicationEventPublisher); + } + return authorizedClientProvider; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java index c0c8bee93ee..bcd130063e6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.function.Consumer; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; @@ -359,6 +360,8 @@ public final class RefreshTokenGrantBuilder implements Builder { private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private ApplicationEventPublisher eventPublisher; + private Duration clockSkew; private Clock clock; @@ -379,6 +382,17 @@ public RefreshTokenGrantBuilder accessTokenResponseClient( return this; } + /** + * Sets the {@link ApplicationEventPublisher} used when an access token is + * refreshed. + * @param eventPublisher the {@link ApplicationEventPublisher} + * @return the {@link RefreshTokenGrantBuilder} + */ + public RefreshTokenGrantBuilder eventPublisher(ApplicationEventPublisher eventPublisher) { + this.eventPublisher = eventPublisher; + return this; + } + /** * Sets the maximum acceptable clock skew, which is used when checking the access * token expiry. An access token is considered expired if @@ -414,6 +428,9 @@ public OAuth2AuthorizedClientProvider build() { if (this.accessTokenResponseClient != null) { authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); } + if (this.eventPublisher != null) { + authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher); + } if (this.clockSkew != null) { authorizedClientProvider.setClockSkew(this.clockSkew); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index 410a33fda18..17dc2ad16b9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -24,10 +24,13 @@ import java.util.HashSet; import java.util.Set; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.lang.Nullable; import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Token; @@ -43,10 +46,13 @@ * @see OAuth2AuthorizedClientProvider * @see DefaultRefreshTokenTokenResponseClient */ -public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { +public final class RefreshTokenOAuth2AuthorizedClientProvider + implements OAuth2AuthorizedClientProvider, ApplicationEventPublisherAware { private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + private ApplicationEventPublisher eventPublisher; + private Duration clockSkew = Duration.ofSeconds(60); private Clock clock = Clock.systemUTC(); @@ -91,8 +97,17 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { authorizedClient.getClientRegistration(), authorizedClient.getAccessToken(), authorizedClient.getRefreshToken(), scopes); OAuth2AccessTokenResponse tokenResponse = getTokenResponse(authorizedClient, refreshTokenGrantRequest); - return new OAuth2AuthorizedClient(context.getAuthorizedClient().getClientRegistration(), - context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); + + OAuth2AuthorizedClient updatedOAuth2AuthorizedClient = new OAuth2AuthorizedClient( + authorizedClient.getClientRegistration(), context.getPrincipal().getName(), + tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); + + if (this.eventPublisher != null) { + this.eventPublisher + .publishEvent(new OAuth2TokenRefreshedEvent(this, updatedOAuth2AuthorizedClient, tokenResponse)); + } + + return updatedOAuth2AuthorizedClient; } private OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizedClient authorizedClient, @@ -149,4 +164,9 @@ public void setClock(Clock clock) { this.clock = clock; } + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { + this.eventPublisher = applicationEventPublisher; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2TokenRefreshedEvent.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2TokenRefreshedEvent.java new file mode 100644 index 00000000000..f92091d4cd7 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2TokenRefreshedEvent.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.event; + +import org.springframework.context.ApplicationEvent; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; + +/** + * An event that is published when an OAuth2 access token is refreshed. + */ +public class OAuth2TokenRefreshedEvent extends ApplicationEvent { + + private final OAuth2AuthorizedClient authorizedClient; + + private final OAuth2AccessTokenResponse accessTokenResponse; + + public OAuth2TokenRefreshedEvent(Object source, OAuth2AuthorizedClient authorizedClient, + OAuth2AccessTokenResponse accessTokenResponse) { + super(source); + this.authorizedClient = authorizedClient; + this.accessTokenResponse = accessTokenResponse; + } + + public OAuth2AuthorizedClient getAuthorizedClient() { + return this.authorizedClient; + } + + public OAuth2AccessTokenResponse getAccessTokenResponse() { + return this.accessTokenResponse; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandler.java new file mode 100644 index 00000000000..d1af1a4f485 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandler.java @@ -0,0 +1,139 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.authentication; + +import java.util.Map; + +import org.springframework.context.ApplicationListener; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtDecoderFactory; +import org.springframework.security.oauth2.jwt.JwtException; +import org.springframework.util.Assert; + +/** + * An {@link ApplicationListener} that listens for {@link OAuth2TokenRefreshedEvent}s + */ +public class RefreshOidcIdTokenHandler implements ApplicationListener { + + private static final String MISSING_ID_TOKEN_ERROR_CODE = "missing_id_token"; + + private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token"; + + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + + private JwtDecoderFactory jwtDecoderFactory = new OidcIdTokenDecoderFactory(); + + @Override + public void onApplicationEvent(OAuth2TokenRefreshedEvent event) { + OAuth2AuthorizedClient authorizedClient = event.getAuthorizedClient(); + + if (!authorizedClient.getClientRegistration().getScopes().contains(OidcScopes.OPENID)) { + return; + } + + Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication(); + if (!(authentication instanceof OAuth2AuthenticationToken oauth2Authentication)) { + return; + } + if (!(authentication.getPrincipal() instanceof DefaultOidcUser defaultOidcUser)) { + return; + } + + OAuth2AccessTokenResponse accessTokenResponse = event.getAccessTokenResponse(); + + String idToken = (String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN); + if (idToken == null || idToken.isBlank()) { + OAuth2Error missingIdTokenError = new OAuth2Error(MISSING_ID_TOKEN_ERROR_CODE, + "ID token is missing in the token response", null); + throw new OAuth2AuthenticationException(missingIdTokenError, missingIdTokenError.toString()); + } + + ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); + OidcIdToken refreshedOidcToken = createOidcToken(clientRegistration, accessTokenResponse); + updateSecurityContext(oauth2Authentication, defaultOidcUser, refreshedOidcToken); + } + + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + + /** + * Sets the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature + * verification. The factory returns a {@link JwtDecoder} associated to the provided + * {@link ClientRegistration}. + * @param jwtDecoderFactory the {@link JwtDecoderFactory} used for {@link OidcIdToken} + * signature verification + */ + public final void setJwtDecoderFactory(JwtDecoderFactory jwtDecoderFactory) { + Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null"); + this.jwtDecoderFactory = jwtDecoderFactory; + } + + private void updateSecurityContext(OAuth2AuthenticationToken oauth2Authentication, DefaultOidcUser defaultOidcUser, + OidcIdToken refreshedOidcToken) { + OidcUser oidcUser = new DefaultOidcUser(defaultOidcUser.getAuthorities(), refreshedOidcToken, + defaultOidcUser.getUserInfo(), StandardClaimNames.SUB); + + SecurityContext context = this.securityContextHolderStrategy.createEmptyContext(); + context.setAuthentication(new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(), + oauth2Authentication.getAuthorizedClientRegistrationId())); + + this.securityContextHolderStrategy.setContext(context); + } + + private OidcIdToken createOidcToken(ClientRegistration clientRegistration, + OAuth2AccessTokenResponse accessTokenResponse) { + JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration); + Jwt jwt = getJwt(accessTokenResponse, jwtDecoder); + return new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims()); + } + + private Jwt getJwt(OAuth2AccessTokenResponse accessTokenResponse, JwtDecoder jwtDecoder) { + try { + Map parameters = accessTokenResponse.getAdditionalParameters(); + return jwtDecoder.decode((String) parameters.get(OidcParameterNames.ID_TOKEN)); + } + catch (JwtException ex) { + OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), null); + throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex); + } + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java index 86ae003eff2..dc4c4232004 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java @@ -25,10 +25,12 @@ import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -251,4 +253,55 @@ public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllega + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); } + @Test + public void shouldPublishEventWhenTokenRefreshed() { + OAuth2TokenRefreshedAwareEventPublisher eventPublisher = new OAuth2TokenRefreshedAwareEventPublisher(); + this.authorizedClientProvider.setApplicationEventPublisher(eventPublisher); + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses + .accessTokenResponse() + .refreshToken("new-refresh-token") + .build(); + // @formatter:on + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + this.authorizedClientProvider.authorize(authorizationContext); + assertThat(eventPublisher.flag).isTrue(); + } + + @Test + public void shouldNotPublishEventWhenTokenNotRefreshed() { + OAuth2TokenRefreshedAwareEventPublisher eventPublisher = new OAuth2TokenRefreshedAwareEventPublisher(); + this.authorizedClientProvider.setApplicationEventPublisher(eventPublisher); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken()); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + this.authorizedClientProvider.authorize(authorizationContext); + assertThat(eventPublisher.flag).isFalse(); + } + + private static class OAuth2TokenRefreshedAwareEventPublisher implements ApplicationEventPublisher { + + Boolean flag = false; + + @Override + public void publishEvent(Object event) { + if (OAuth2TokenRefreshedEvent.class.isAssignableFrom(event.getClass())) { + this.flag = true; + } + } + + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandlerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandlerTests.java new file mode 100644 index 00000000000..61ff5aae892 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandlerTests.java @@ -0,0 +1,284 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.authentication; + +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; +import org.springframework.security.oauth2.core.user.DefaultOAuth2User; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtDecoderFactory; +import org.springframework.security.oauth2.jwt.JwtException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +class RefreshOidcIdTokenHandlerTests { + + private static final String EXISTING_ID_TOKEN_VALUE = "id-token-value"; + + private static final String REFRESHED_ID_TOKEN_VALUE = "new-id-token-value"; + + private static final String EXISTING_ACCESS_TOKEN_VALUE = "token-value"; + + private static final String REFRESHED_ACCESS_TOKEN_VALUE = "new-token-value"; + + private RefreshOidcIdTokenHandler handler; + + private RefreshTokenOAuth2AuthorizedClientProvider provider; + + private ClientRegistration clientRegistration; + + private OAuth2AuthorizedClient authorizedClient; + + private JwtDecoder jwtDecoder; + + private SecurityContext securityContext; + + private OidcIdToken existingIdToken; + + @BeforeEach + void setUp() { + this.handler = new RefreshOidcIdTokenHandler(); + + this.clientRegistration = createClientRegistrationWithScopes(OidcScopes.OPENID); + this.authorizedClient = createAuthorizedClient(this.clientRegistration); + + this.provider = mock(RefreshTokenOAuth2AuthorizedClientProvider.class); + + JwtDecoderFactory jwtDecoderFactory = mock(JwtDecoderFactory.class); + this.jwtDecoder = mock(JwtDecoder.class); + SecurityContextHolderStrategy securityContextHolderStrategy = mock(SecurityContextHolderStrategy.class); + this.securityContext = mock(SecurityContext.class); + + this.handler.setJwtDecoderFactory(jwtDecoderFactory); + this.handler.setSecurityContextHolderStrategy(securityContextHolderStrategy); + + given(jwtDecoderFactory.createDecoder(any())).willReturn(this.jwtDecoder); + given(securityContextHolderStrategy.createEmptyContext()).willReturn(this.securityContext); + given(securityContextHolderStrategy.getContext()).willReturn(this.securityContext); + + Map claims = new HashMap<>(); + claims.put("sub", "subject"); + Jwt existingIdTokenJwt = new Jwt(EXISTING_ID_TOKEN_VALUE, Instant.now(), Instant.now().plusSeconds(3600), + Map.of("alg", "RS256"), claims); + Jwt refreshedIdTokenJwt = new Jwt(REFRESHED_ID_TOKEN_VALUE, Instant.now(), Instant.now().plusSeconds(3600), + Map.of("alg", "RS256"), claims); + + this.existingIdToken = new OidcIdToken(existingIdTokenJwt.getTokenValue(), existingIdTokenJwt.getIssuedAt(), + existingIdTokenJwt.getExpiresAt(), existingIdTokenJwt.getClaims()); + + given(this.jwtDecoder.decode(existingIdTokenJwt.getTokenValue())).willReturn(existingIdTokenJwt); + given(this.jwtDecoder.decode(refreshedIdTokenJwt.getTokenValue())).willReturn(refreshedIdTokenJwt); + } + + @Test + void handleEventWhenValidIdTokenThenUpdatesSecurityContext() { + + DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), + this.existingIdToken); + OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser, + existingUser.getAuthorities(), "registration-id"); + given(this.securityContext.getAuthentication()).willReturn(existingAuth); + + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withToken(REFRESHED_ACCESS_TOKEN_VALUE) + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(3600) + .additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE)) + .build(); + + OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient, + accessTokenResponse); + this.handler.onApplicationEvent(event); + + ArgumentCaptor authenticationCaptor = ArgumentCaptor + .forClass(OAuth2AuthenticationToken.class); + verify(this.securityContext).setAuthentication(authenticationCaptor.capture()); + + OAuth2AuthenticationToken newAuthentication = authenticationCaptor.getValue(); + assertThat(newAuthentication.getPrincipal()).isInstanceOf(DefaultOidcUser.class); + DefaultOidcUser newUser = (DefaultOidcUser) newAuthentication.getPrincipal(); + assertThat(newUser.getIdToken().getTokenValue()).isEqualTo(REFRESHED_ID_TOKEN_VALUE); + } + + @Test + void handleEventWhenAuthorizedClientIsNotOidcThenDoesNothing() { + + this.clientRegistration = createClientRegistrationWithScopes("read"); + this.authorizedClient = createAuthorizedClient(this.clientRegistration); + + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withToken(REFRESHED_ACCESS_TOKEN_VALUE) + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(3600) + .additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE)) + .build(); + + OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient, + accessTokenResponse); + + this.handler.onApplicationEvent(event); + + verify(this.securityContext, never()).setAuthentication(any()); + verify(this.jwtDecoder, never()).decode(any()); + } + + @Test + void handleEventWhenAuthenticationNotOAuth2AuthenticationTokenThenDoesNothing() { + + given(this.securityContext.getAuthentication()).willReturn(mock(TestingAuthenticationToken.class)); + + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withToken(REFRESHED_ACCESS_TOKEN_VALUE) + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(3600) + .additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE)) + .build(); + + OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient, + accessTokenResponse); + + this.handler.onApplicationEvent(event); + + verify(this.securityContext, never()).setAuthentication(any()); + } + + @Test + void handleEventWhenNotOidcUserThenDoesNothing() { + + OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken( + new DefaultOAuth2User(Collections.emptySet(), + Collections.singletonMap("custom-attribute", "test-subject"), "custom-attribute"), + AuthorityUtils.createAuthorityList("ROLE_USER"), "registration-id"); + given(this.securityContext.getAuthentication()).willReturn(existingAuth); + + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withToken(REFRESHED_ACCESS_TOKEN_VALUE) + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(3600) + .additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE)) + .build(); + + OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient, + accessTokenResponse); + + this.handler.onApplicationEvent(event); + + verify(this.securityContext, never()).setAuthentication(any()); + } + + @Test + void handleEventWhenMissingIdTokenThenThrowsException() { + + DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), + this.existingIdToken); + OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser, + existingUser.getAuthorities(), "registration-id"); + given(this.securityContext.getAuthentication()).willReturn(existingAuth); + + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withToken(REFRESHED_ACCESS_TOKEN_VALUE) + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(3600) + .additionalParameters(new HashMap<>()) // missing ID token + .build(); + + OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient, + accessTokenResponse); + + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.handler.onApplicationEvent(event)) + .withMessageContaining("missing_id_token"); + } + + @Test + void handleEventWhenInvalidIdTokenThenThrowsException() { + + DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), + this.existingIdToken); + OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser, + existingUser.getAuthorities(), "registration-id"); + given(this.securityContext.getAuthentication()).willReturn(existingAuth); + + given(this.jwtDecoder.decode(any())).willThrow(new JwtException("Invalid token")); + + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withToken(REFRESHED_ACCESS_TOKEN_VALUE) + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(3600) + .additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, "invalid-id-token")) + .build(); + + OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient, + accessTokenResponse); + + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.handler.onApplicationEvent(event)) + .withMessageContaining("invalid_id_token"); + } + + private ClientRegistration createClientRegistrationWithScopes(String... scope) { + return ClientRegistration.withRegistrationId("registration-id") + .clientId("client-id") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUri("http://localhost") + .scope(scope) + .authorizationUri("https://provider.com/oauth2/authorize") + .tokenUri("https://provider.com/oauth2/token") + .jwkSetUri("https://provider.com/jwk") + .userInfoUri("https://provider.com/user") + .build(); + } + + private static OAuth2AuthorizedClient createAuthorizedClient(ClientRegistration clientRegistration) { + return new OAuth2AuthorizedClient(clientRegistration, "principal-name", + new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, EXISTING_ACCESS_TOKEN_VALUE, Instant.now(), + Instant.now().plusSeconds(3600))); + } + +} From 8ab4ddd806055d13fe9f7acc29be166eef698a7e Mon Sep 17 00:00:00 2001 From: Steve Riesenberg <5248162+sjohnr@users.noreply.github.com> Date: Thu, 27 Feb 2025 14:09:10 -0600 Subject: [PATCH 2/2] Polish gh-16589 --- .../OAuth2ClientConfiguration.java | 17 +- .../oauth2/client/OAuth2LoginConfigurer.java | 34 +- .../OidcUserRefreshedEventListener.java | 86 +++ ...Auth2AuthorizedClientManagerRegistrar.java | 2 +- ...reshedEventListenerConfigurationTests.java | 504 ++++++++++++++++++ .../OidcUserRefreshedEventListenerTests.java | 135 +++++ ...OAuth2AuthorizedClientProviderBuilder.java | 3 +- ...shTokenOAuth2AuthorizedClientProvider.java | 20 +- .../OAuth2AuthorizedClientRefreshedEvent.java | 69 +++ .../event/OAuth2TokenRefreshedEvent.java | 47 -- ...uthorizedClientRefreshedEventListener.java | 219 ++++++++ .../RefreshOidcIdTokenHandler.java | 139 ----- .../event/OidcUserRefreshedEvent.java | 98 ++++ ...enOAuth2AuthorizedClientProviderTests.java | 6 +- ...izedClientRefreshedEventListenerTests.java | 420 +++++++++++++++ .../RefreshOidcIdTokenHandlerTests.java | 284 ---------- 16 files changed, 1580 insertions(+), 503 deletions(-) create mode 100644 config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListener.java create mode 100644 config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java create mode 100644 config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2AuthorizedClientRefreshedEvent.java delete mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2TokenRefreshedEvent.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizedClientRefreshedEventListener.java delete mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandler.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/event/OidcUserRefreshedEvent.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizedClientRefreshedEventListenerTests.java delete mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandlerTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 55de62810d5..0bdd2bfe340 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,9 +34,8 @@ import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; -import org.springframework.context.ApplicationContext; -import org.springframework.context.ApplicationContextAware; import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.annotation.AnnotationBeanNameGenerator; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -163,7 +162,7 @@ private OAuth2AuthorizedClientManager getAuthorizedClientManager() { * @since 6.2.0 */ static final class OAuth2AuthorizedClientManagerRegistrar - implements ApplicationContextAware, BeanDefinitionRegistryPostProcessor, BeanFactoryAware { + implements ApplicationEventPublisherAware, BeanDefinitionRegistryPostProcessor, BeanFactoryAware { static final String BEAN_NAME = "authorizedClientManagerRegistrar"; @@ -182,7 +181,7 @@ static final class OAuth2AuthorizedClientManagerRegistrar private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator(); - private ApplicationEventPublisher eventPublisher; + private ApplicationEventPublisher applicationEventPublisher; private ListableBeanFactory beanFactory; @@ -307,8 +306,8 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider( authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); } - if (this.eventPublisher != null) { - authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher); + if (this.applicationEventPublisher != null) { + authorizedClientProvider.setApplicationEventPublisher(this.applicationEventPublisher); } return authorizedClientProvider; @@ -433,8 +432,8 @@ private T getBeanOfType(ResolvableType resolvableType) { } @Override - public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { - this.eventPublisher = applicationContext; + public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { + this.applicationEventPublisher = applicationEventPublisher; } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index 155918dddb0..dcacd52df18 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -57,7 +57,7 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider; -import org.springframework.security.oauth2.client.oidc.authentication.RefreshOidcIdTokenHandler; +import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizedClientRefreshedEventListener; import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry; import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation; import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry; @@ -91,6 +91,7 @@ import org.springframework.security.web.authentication.session.SessionAuthenticationException; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.util.matcher.AndRequestMatcher; @@ -387,23 +388,26 @@ public void init(B http) throws Exception { OAuth2UserService oidcUserService = getOidcUserService(); OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider = new OidcAuthorizationCodeAuthenticationProvider( accessTokenResponseClient, oidcUserService); + OidcAuthorizedClientRefreshedEventListener oidcAuthorizedClientRefreshedEventListener = new OidcAuthorizedClientRefreshedEventListener(); + oidcAuthorizedClientRefreshedEventListener.setUserService(oidcUserService); + oidcAuthorizedClientRefreshedEventListener + .setApplicationEventPublisher(http.getSharedObject(ApplicationContext.class)); + JwtDecoderFactory jwtDecoderFactory = this.getJwtDecoderFactoryBean(); if (jwtDecoderFactory != null) { oidcAuthorizationCodeAuthenticationProvider.setJwtDecoderFactory(jwtDecoderFactory); + oidcAuthorizedClientRefreshedEventListener.setJwtDecoderFactory(jwtDecoderFactory); } if (userAuthoritiesMapper != null) { oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper); + oidcAuthorizedClientRefreshedEventListener.setAuthoritiesMapper(userAuthoritiesMapper); } - http.authenticationProvider(this.postProcess(oidcAuthorizationCodeAuthenticationProvider)); + oidcAuthorizationCodeAuthenticationProvider = this.postProcess(oidcAuthorizationCodeAuthenticationProvider); + http.authenticationProvider(oidcAuthorizationCodeAuthenticationProvider); - RefreshOidcIdTokenHandler refreshOidcIdTokenHandler = new RefreshOidcIdTokenHandler(); - if (this.getSecurityContextHolderStrategy() != null) { - refreshOidcIdTokenHandler.setSecurityContextHolderStrategy(this.getSecurityContextHolderStrategy()); - } - if (jwtDecoderFactory != null) { - refreshOidcIdTokenHandler.setJwtDecoderFactory(jwtDecoderFactory); - } - registerDelegateApplicationListener(refreshOidcIdTokenHandler); + oidcAuthorizedClientRefreshedEventListener = this.postProcess(oidcAuthorizedClientRefreshedEventListener); + registerDelegateApplicationListener(oidcAuthorizedClientRefreshedEventListener); + configureOidcUserRefreshedEventListener(http); } else { http.authenticationProvider(new OidcAuthenticationRequestChecker()); @@ -631,6 +635,16 @@ private void configureOidcSessionRegistry(B http) { registerDelegateApplicationListener(listener); } + private void configureOidcUserRefreshedEventListener(B http) { + OidcUserRefreshedEventListener oidcUserRefreshedEventListener = new OidcUserRefreshedEventListener(); + oidcUserRefreshedEventListener.setSecurityContextHolderStrategy(this.getSecurityContextHolderStrategy()); + SecurityContextRepository securityContextRepository = http.getSharedObject(SecurityContextRepository.class); + if (securityContextRepository != null) { + oidcUserRefreshedEventListener.setSecurityContextRepository(securityContextRepository); + } + registerDelegateApplicationListener(oidcUserRefreshedEventListener); + } + private void registerDelegateApplicationListener(ApplicationListener delegate) { DelegatingApplicationListener delegating = getBeanOrNull( ResolvableType.forType(DelegatingApplicationListener.class)); diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListener.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListener.java new file mode 100644 index 00000000000..e9dce35b69e --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListener.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.configurers.oauth2.client; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.springframework.context.ApplicationListener; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.oauth2.client.oidc.authentication.event.OidcUserRefreshedEvent; +import org.springframework.security.web.context.HttpSessionSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; +import org.springframework.util.Assert; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; + +/** + * An {@link ApplicationListener} that listens for events of type + * {@link OidcUserRefreshedEvent} and refreshes the {@link SecurityContext}. + * + * @author Steve Riesenberg + * @since 6.5 + * @see org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider + */ +final class OidcUserRefreshedEventListener implements ApplicationListener { + + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + + private SecurityContextRepository securityContextRepository = new HttpSessionSecurityContextRepository(); + + @Override + public void onApplicationEvent(OidcUserRefreshedEvent event) { + SecurityContext securityContext = this.securityContextHolderStrategy.createEmptyContext(); + securityContext.setAuthentication(event.getAuthentication()); + this.securityContextHolderStrategy.setContext(securityContext); + + RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes(); + if (!(requestAttributes instanceof ServletRequestAttributes servletRequestAttributes)) { + return; + } + + HttpServletRequest request = servletRequestAttributes.getRequest(); + HttpServletResponse response = servletRequestAttributes.getResponse(); + this.securityContextRepository.saveContext(securityContext, request, response); + } + + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * @param securityContextHolderStrategy the {@link SecurityContextHolderStrategy} to + * use + */ + void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} upon + * receiving an {@link OidcUserRefreshedEvent}. + * @param securityContextRepository the {@link SecurityContextRepository} to use + */ + void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + +} diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java b/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java index d2252435f7b..dc131adcd45 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java new file mode 100644 index 00000000000..5ad7f9726e2 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java @@ -0,0 +1,504 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.configurers.oauth2.client; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import jakarta.servlet.http.HttpServletRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.config.Customizer; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; +import org.springframework.security.config.test.SpringTestContext; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextImpl; +import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.core.user.DefaultOAuth2User; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtDecoderFactory; +import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.context.SecurityContextRepository; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Tests for {@link OidcUserRefreshedEventListener} with {@link OAuth2LoginConfigurer}. + * + * @author Steve Riesenberg + */ +public class OidcUserRefreshedEventListenerConfigurationTests { + + // @formatter:off + private static final ClientRegistration GOOGLE_CLIENT_REGISTRATION = CommonOAuth2Provider.GOOGLE + .getBuilder("google") + .clientId("clientId") + .clientSecret("clientSecret") + .build(); + // @formatter:on + + // @formatter:off + private static final ClientRegistration GITHUB_CLIENT_REGISTRATION = CommonOAuth2Provider.GITHUB + .getBuilder("github") + .clientId("clientId") + .clientSecret("clientSecret") + .build(); + // @formatter:on + + private static final String SUBJECT = "surfer-dude"; + + private static final String ACCESS_TOKEN_VALUE = "hang-ten"; + + private static final String REFRESH_TOKEN_VALUE = "surfs-up"; + + private static final String ID_TOKEN_VALUE = "beach-break"; + + public final SpringTestContext spring = new SpringTestContext(this); + + @Autowired + private SecurityContextRepository securityContextRepository; + + @Autowired + private OAuth2AuthorizedClientRepository authorizedClientRepository; + + @Autowired + private OAuth2AccessTokenResponseClient refreshTokenAccessTokenResponseClient; + + @Autowired + private JwtDecoder jwtDecoder; + + @Autowired + private OidcUserService oidcUserService; + + @Autowired + private OAuth2AuthorizedClientManager authorizedClientManager; + + private MockHttpServletRequest request; + + private MockHttpServletResponse response; + + @BeforeEach + public void setUp() { + this.request = new MockHttpServletRequest("GET", ""); + this.request.setServletPath("/"); + this.response = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.request, this.response)); + } + + @AfterEach + public void cleanUp() { + SecurityContextHolder.clearContext(); + RequestContextHolder.resetRequestAttributes(); + } + + @Test + public void authorizeWhenAccessTokenResponseMissingOpenidScopeThenOidcUserNotRefreshed() { + this.spring.register(OAuth2LoginWithOAuth2ClientConfig.class).autowire(); + + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = createAccessTokenResponse(); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(HttpServletRequest.class))) + .willReturn(authorizedClient); + given(this.refreshTokenAccessTokenResponseClient.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + + OAuth2AuthenticationToken authentication = createAuthenticationToken(GOOGLE_CLIENT_REGISTRATION); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(GOOGLE_CLIENT_REGISTRATION.getRegistrationId()) + .principal(authentication) + .build(); + OAuth2AuthorizedClient refreshedAuthorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(refreshedAuthorizedClient).isNotNull(); + verifyNoInteractions(this.securityContextRepository, this.jwtDecoder, this.oidcUserService); + } + + @Test + public void authorizeWhenAccessTokenResponseMissingIdTokenThenOidcUserNotRefreshed() { + this.spring.register(OAuth2LoginWithOAuth2ClientConfig.class).autowire(); + + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.oidcAccessTokenResponse() + .build(); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(HttpServletRequest.class))) + .willReturn(authorizedClient); + given(this.refreshTokenAccessTokenResponseClient.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + + OAuth2AuthenticationToken authentication = createAuthenticationToken(GOOGLE_CLIENT_REGISTRATION); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(GOOGLE_CLIENT_REGISTRATION.getRegistrationId()) + .principal(authentication) + .build(); + OAuth2AuthorizedClient refreshedAuthorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(refreshedAuthorizedClient).isNotNull(); + verifyNoInteractions(this.securityContextRepository, this.jwtDecoder, this.oidcUserService); + } + + @Test + public void authorizeWhenAuthenticationIsNotOAuth2ThenOidcUserNotRefreshed() { + this.spring.register(OAuth2LoginWithOAuth2ClientConfig.class).autowire(); + + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = createAccessTokenResponse(OidcScopes.OPENID); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(HttpServletRequest.class))) + .willReturn(authorizedClient); + given(this.refreshTokenAccessTokenResponseClient.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken(SUBJECT, null); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + SecurityContextHolder.setContext(securityContext); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(GOOGLE_CLIENT_REGISTRATION.getRegistrationId()) + .principal(authentication) + .build(); + OAuth2AuthorizedClient refreshedAuthorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(refreshedAuthorizedClient).isNotNull(); + verifyNoInteractions(this.securityContextRepository, this.jwtDecoder, this.oidcUserService); + } + + @Test + public void authorizeWhenAuthenticationIsCustomThenOidcUserNotRefreshed() { + this.spring.register(OAuth2LoginWithOAuth2ClientConfig.class).autowire(); + + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = createAccessTokenResponse(OidcScopes.OPENID); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(HttpServletRequest.class))) + .willReturn(authorizedClient); + given(this.refreshTokenAccessTokenResponseClient.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + + OidcUser oidcUser = createOidcUser(); + OAuth2AuthenticationToken authentication = new CustomOAuth2AuthenticationToken(oidcUser, + oidcUser.getAuthorities(), GOOGLE_CLIENT_REGISTRATION.getRegistrationId()); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + SecurityContextHolder.setContext(securityContext); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(GOOGLE_CLIENT_REGISTRATION.getRegistrationId()) + .principal(authentication) + .build(); + OAuth2AuthorizedClient refreshedAuthorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(refreshedAuthorizedClient).isNotNull(); + verifyNoInteractions(this.securityContextRepository, this.jwtDecoder, this.oidcUserService); + } + + @Test + public void authorizeWhenPrincipalIsOAuth2UserThenOidcUserNotRefreshed() { + this.spring.register(OAuth2LoginWithOAuth2ClientConfig.class).autowire(); + + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = createAccessTokenResponse(OidcScopes.OPENID); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(HttpServletRequest.class))) + .willReturn(authorizedClient); + given(this.refreshTokenAccessTokenResponseClient.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + + Map attributes = Map.of(StandardClaimNames.SUB, SUBJECT); + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("OAUTH2_USER"), attributes, + StandardClaimNames.SUB); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(oauth2User, + oauth2User.getAuthorities(), GOOGLE_CLIENT_REGISTRATION.getRegistrationId()); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + SecurityContextHolder.setContext(securityContext); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(GOOGLE_CLIENT_REGISTRATION.getRegistrationId()) + .principal(authentication) + .build(); + OAuth2AuthorizedClient refreshedAuthorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(refreshedAuthorizedClient).isNotNull(); + verifyNoInteractions(this.securityContextRepository, this.jwtDecoder, this.oidcUserService); + } + + @Test + public void authorizeWhenAuthenticationClientRegistrationIdDoesNotMatchThenOidcUserNotRefreshed() { + this.spring.register(OAuth2LoginWithOAuth2ClientConfig.class).autowire(); + + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = createAccessTokenResponse(OidcScopes.OPENID); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(HttpServletRequest.class))) + .willReturn(authorizedClient); + given(this.refreshTokenAccessTokenResponseClient.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + + OAuth2AuthenticationToken authentication = createAuthenticationToken(GITHUB_CLIENT_REGISTRATION); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + SecurityContextHolder.setContext(securityContext); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(GOOGLE_CLIENT_REGISTRATION.getRegistrationId()) + .principal(authentication) + .build(); + OAuth2AuthorizedClient refreshedAuthorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(refreshedAuthorizedClient).isNotNull(); + verifyNoInteractions(this.securityContextRepository, this.jwtDecoder, this.oidcUserService); + } + + @Test + public void authorizeWhenAccessTokenResponseIncludesIdTokenThenOidcUserRefreshed() { + this.spring.register(OAuth2LoginWithOAuth2ClientConfig.class).autowire(); + + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = createAccessTokenResponse(OidcScopes.OPENID); + Jwt jwt = createJwt(); + OidcUser oidcUser = createOidcUser(); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(HttpServletRequest.class))) + .willReturn(authorizedClient); + given(this.refreshTokenAccessTokenResponseClient.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + given(this.jwtDecoder.decode(anyString())).willReturn(jwt); + given(this.oidcUserService.loadUser(any(OidcUserRequest.class))).willReturn(oidcUser); + + OAuth2AuthenticationToken authentication = createAuthenticationToken(GOOGLE_CLIENT_REGISTRATION); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + SecurityContextHolder.setContext(securityContext); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(GOOGLE_CLIENT_REGISTRATION.getRegistrationId()) + .principal(authentication) + .build(); + OAuth2AuthorizedClient refreshedAuthorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(refreshedAuthorizedClient).isNotNull(); + assertThat(refreshedAuthorizedClient).isNotSameAs(authorizedClient); + assertThat(refreshedAuthorizedClient.getClientRegistration()).isEqualTo(GOOGLE_CLIENT_REGISTRATION); + assertThat(refreshedAuthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(refreshedAuthorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); + + ArgumentCaptor refreshTokenGrantRequestCaptor = ArgumentCaptor + .forClass(OAuth2RefreshTokenGrantRequest.class); + ArgumentCaptor userRequestCaptor = ArgumentCaptor.forClass(OidcUserRequest.class); + ArgumentCaptor securityContextCaptor = ArgumentCaptor.forClass(SecurityContext.class); + verify(this.authorizedClientRepository).loadAuthorizedClient(GOOGLE_CLIENT_REGISTRATION.getRegistrationId(), + authentication, this.request); + verify(this.authorizedClientRepository).saveAuthorizedClient(refreshedAuthorizedClient, authentication, + this.request, this.response); + verify(this.refreshTokenAccessTokenResponseClient).getTokenResponse(refreshTokenGrantRequestCaptor.capture()); + verify(this.jwtDecoder).decode(jwt.getTokenValue()); + verify(this.oidcUserService).loadUser(userRequestCaptor.capture()); + verify(this.securityContextRepository).saveContext(securityContextCaptor.capture(), eq(this.request), + eq(this.response)); + verifyNoMoreInteractions(this.authorizedClientRepository, this.jwtDecoder, this.oidcUserService, + this.securityContextRepository); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = refreshTokenGrantRequestCaptor.getValue(); + assertThat(refreshTokenGrantRequest.getClientRegistration()) + .isEqualTo(authorizedClient.getClientRegistration()); + assertThat(refreshTokenGrantRequest.getRefreshToken()).isEqualTo(authorizedClient.getRefreshToken()); + assertThat(refreshTokenGrantRequest.getAccessToken()).isEqualTo(authorizedClient.getAccessToken()); + + OidcUserRequest userRequest = userRequestCaptor.getValue(); + assertThat(userRequest.getClientRegistration()).isEqualTo(GOOGLE_CLIENT_REGISTRATION); + assertThat(userRequest.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(userRequest.getIdToken().getTokenValue()).isEqualTo(jwt.getTokenValue()); + + SecurityContext refreshedSecurityContext = securityContextCaptor.getValue(); + assertThat(refreshedSecurityContext).isNotNull(); + assertThat(refreshedSecurityContext).isNotSameAs(securityContext); + assertThat(refreshedSecurityContext).isSameAs(SecurityContextHolder.getContext()); + assertThat(refreshedSecurityContext.getAuthentication()).isInstanceOf(OAuth2AuthenticationToken.class); + assertThat(refreshedSecurityContext.getAuthentication()).isNotSameAs(authentication); + assertThat(refreshedSecurityContext.getAuthentication().getPrincipal()).isInstanceOf(OidcUser.class); + assertThat(refreshedSecurityContext.getAuthentication().getPrincipal()) + .isNotSameAs(authentication.getPrincipal()); + } + + private OAuth2AuthorizedClient createAuthorizedClient() { + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(30, ChronoUnit.SECONDS); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, ACCESS_TOKEN_VALUE, + issuedAt, expiresAt, Set.of(OidcScopes.OPENID)); + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(REFRESH_TOKEN_VALUE, issuedAt); + + return new OAuth2AuthorizedClient(GOOGLE_CLIENT_REGISTRATION, SUBJECT, accessToken, refreshToken); + } + + private OAuth2AccessTokenResponse createAccessTokenResponse(String... scope) { + Set scopes = Set.of(scope); + Map additionalParameters = new HashMap<>(); + if (scopes.contains(OidcScopes.OPENID)) { + additionalParameters.put(OidcParameterNames.ID_TOKEN, ID_TOKEN_VALUE); + } + + return OAuth2AccessTokenResponse.withToken(ACCESS_TOKEN_VALUE) + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .scopes(scopes) + .refreshToken(REFRESH_TOKEN_VALUE) + .expiresIn(60L) + .additionalParameters(additionalParameters) + .build(); + } + + private Jwt createJwt() { + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.MINUTES); + return TestJwts.jwt() + .subject(SUBJECT) + .tokenValue(ID_TOKEN_VALUE) + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .build(); + } + + private OidcUser createOidcUser() { + Map claims = new HashMap<>(); + claims.put(IdTokenClaimNames.SUB, SUBJECT); + claims.put(IdTokenClaimNames.ISS, "issuer"); + claims.put(IdTokenClaimNames.AUD, List.of("audience1", "audience2")); + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.MINUTES); + OidcIdToken idToken = new OidcIdToken(ID_TOKEN_VALUE, issuedAt, expiresAt, claims); + + return new DefaultOidcUser(AuthorityUtils.createAuthorityList("OIDC_USER"), idToken); + } + + private OAuth2AuthenticationToken createAuthenticationToken(ClientRegistration clientRegistration) { + OidcUser oidcUser = createOidcUser(); + return new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(), + clientRegistration.getRegistrationId()); + } + + @Configuration + @EnableWebSecurity + static class OAuth2LoginWithOAuth2ClientConfig { + + @Bean + SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeHttpRequests((authorize) -> authorize + .anyRequest().authenticated() + ) + .securityContext((securityContext) -> securityContext + .securityContextRepository(this.securityContextRepository()) + ) + .oauth2Login(Customizer.withDefaults()) + .oauth2Client(Customizer.withDefaults()); + // @formatter:on + return http.build(); + } + + @Bean + SecurityContextRepository securityContextRepository() { + return mock(SecurityContextRepository.class); + } + + @Bean + ClientRegistrationRepository clientRegistrationRepository() { + return mock(ClientRegistrationRepository.class); + } + + @Bean + OAuth2AuthorizedClientRepository authorizedClientRepository() { + return mock(OAuth2AuthorizedClientRepository.class); + } + + @Bean + @SuppressWarnings("unchecked") + OAuth2AccessTokenResponseClient refreshTokenAccessTokenResponseClient() { + return mock(OAuth2AccessTokenResponseClient.class); + } + + @Bean + JwtDecoder jwtDecoder() { + return mock(JwtDecoder.class); + } + + @Bean + JwtDecoderFactory jwtDecoderFactory() { + return (clientRegistration) -> jwtDecoder(); + } + + @Bean + OidcUserService oidcUserService() { + return mock(OidcUserService.class); + } + + } + + private static final class CustomOAuth2AuthenticationToken extends OAuth2AuthenticationToken { + + CustomOAuth2AuthenticationToken(OAuth2User principal, Collection authorities, + String authorizedClientRegistrationId) { + super(principal, authorities, authorizedClientRegistrationId); + } + + } + +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java new file mode 100644 index 00000000000..6b8f82a8bd0 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java @@ -0,0 +1,135 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.configurers.oauth2.client; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.oidc.authentication.event.OidcUserRefreshedEvent; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers; +import org.springframework.security.web.context.SecurityContextRepository; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Tests for {@link OidcUserRefreshedEventListener}. + * + * @author Steve Riesenberg + */ +public class OidcUserRefreshedEventListenerTests { + + private OidcUserRefreshedEventListener eventListener; + + private SecurityContextRepository securityContextRepository; + + private MockHttpServletRequest request; + + private MockHttpServletResponse response; + + @BeforeEach + public void setUp() { + this.securityContextRepository = mock(SecurityContextRepository.class); + this.eventListener = new OidcUserRefreshedEventListener(); + this.eventListener.setSecurityContextRepository(this.securityContextRepository); + + this.request = new MockHttpServletRequest("GET", ""); + this.request.setServletPath("/"); + this.response = new MockHttpServletResponse(); + } + + @AfterEach + public void cleanUp() { + SecurityContextHolder.clearContext(); + RequestContextHolder.resetRequestAttributes(); + } + + @Test + public void setSecurityContextHolderStrategyWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.eventListener.setSecurityContextHolderStrategy(null)) + .withMessage("securityContextHolderStrategy cannot be null"); + } + + @Test + public void setSecurityContextRepositoryWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.eventListener.setSecurityContextRepository(null)) + .withMessage("securityContextRepository cannot be null"); + } + + @Test + public void onApplicationEventWhenRequestAttributesSetThenSecurityContextSaved() { + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.request, this.response)); + + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.oidcAccessTokenResponse() + .build(); + OidcUser oldOidcUser = TestOidcUsers.create(); + OidcUser newOidcUser = TestOidcUsers.create(); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(newOidcUser, + newOidcUser.getAuthorities(), "test"); + OidcUserRefreshedEvent event = new OidcUserRefreshedEvent(accessTokenResponse, oldOidcUser, newOidcUser, + authentication); + this.eventListener.onApplicationEvent(event); + + ArgumentCaptor securityContextCaptor = ArgumentCaptor.forClass(SecurityContext.class); + verify(this.securityContextRepository).saveContext(securityContextCaptor.capture(), eq(this.request), + eq(this.response)); + verifyNoMoreInteractions(this.securityContextRepository); + + SecurityContext securityContext = securityContextCaptor.getValue(); + assertThat(securityContext).isNotNull(); + assertThat(securityContext).isSameAs(SecurityContextHolder.getContext()); + assertThat(securityContext.getAuthentication()).isSameAs(authentication); + } + + @Test + public void onApplicationEventWhenRequestAttributesNotSetThenSecurityContextNotSaved() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.oidcAccessTokenResponse() + .build(); + OidcUser oldOidcUser = TestOidcUsers.create(); + OidcUser newOidcUser = TestOidcUsers.create(); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(newOidcUser, + newOidcUser.getAuthorities(), "test"); + OidcUserRefreshedEvent event = new OidcUserRefreshedEvent(accessTokenResponse, oldOidcUser, newOidcUser, + authentication); + OidcUserRefreshedEventListener eventListener = new OidcUserRefreshedEventListener(); + eventListener.setSecurityContextRepository(this.securityContextRepository); + eventListener.onApplicationEvent(event); + verifyNoInteractions(this.securityContextRepository); + + SecurityContext securityContext = SecurityContextHolder.getContext(); + assertThat(securityContext).isNotNull(); + assertThat(securityContext.getAuthentication()).isSameAs(authentication); + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java index bcd130063e6..316b2d5c131 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -387,6 +387,7 @@ public RefreshTokenGrantBuilder accessTokenResponseClient( * refreshed. * @param eventPublisher the {@link ApplicationEventPublisher} * @return the {@link RefreshTokenGrantBuilder} + * @since 6.5 */ public RefreshTokenGrantBuilder eventPublisher(ApplicationEventPublisher eventPublisher) { this.eventPublisher = eventPublisher; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index 17dc2ad16b9..e306c7e9fe6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; -import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent; +import org.springframework.security.oauth2.client.event.OAuth2AuthorizedClientRefreshedEvent; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Token; @@ -51,7 +51,7 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); - private ApplicationEventPublisher eventPublisher; + private ApplicationEventPublisher applicationEventPublisher; private Duration clockSkew = Duration.ofSeconds(60); @@ -98,16 +98,17 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { authorizedClient.getRefreshToken(), scopes); OAuth2AccessTokenResponse tokenResponse = getTokenResponse(authorizedClient, refreshTokenGrantRequest); - OAuth2AuthorizedClient updatedOAuth2AuthorizedClient = new OAuth2AuthorizedClient( + OAuth2AuthorizedClient refreshedAuthorizedClient = new OAuth2AuthorizedClient( authorizedClient.getClientRegistration(), context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); - if (this.eventPublisher != null) { - this.eventPublisher - .publishEvent(new OAuth2TokenRefreshedEvent(this, updatedOAuth2AuthorizedClient, tokenResponse)); + if (this.applicationEventPublisher != null) { + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + tokenResponse, refreshedAuthorizedClient); + this.applicationEventPublisher.publishEvent(authorizedClientRefreshedEvent); } - return updatedOAuth2AuthorizedClient; + return refreshedAuthorizedClient; } private OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizedClient authorizedClient, @@ -166,7 +167,8 @@ public void setClock(Clock clock) { @Override public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { - this.eventPublisher = applicationEventPublisher; + Assert.notNull(applicationEventPublisher, "applicationEventPublisher cannot be null"); + this.applicationEventPublisher = applicationEventPublisher; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2AuthorizedClientRefreshedEvent.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2AuthorizedClientRefreshedEvent.java new file mode 100644 index 00000000000..aa70bf76f35 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2AuthorizedClientRefreshedEvent.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.event; + +import java.io.Serial; + +import org.springframework.context.ApplicationEvent; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.util.Assert; + +/** + * An event that is published when an {@link OAuth2AuthorizedClient} is refreshed as a + * result of using a {@code refresh_token} to obtain an OAuth 2.0 Access Token Response. + * + * @author Steve Riesenberg + * @since 6.5 + */ +public final class OAuth2AuthorizedClientRefreshedEvent extends ApplicationEvent { + + @Serial + private static final long serialVersionUID = -2178028089321556476L; + + private final OAuth2AuthorizedClient authorizedClient; + + /** + * Creates a new instance with the provided parameters. + * @param accessTokenResponse the {@link OAuth2AccessTokenResponse} that triggered the + * event + * @param authorizedClient the refreshed {@link OAuth2AuthorizedClient} + */ + public OAuth2AuthorizedClientRefreshedEvent(OAuth2AccessTokenResponse accessTokenResponse, + OAuth2AuthorizedClient authorizedClient) { + super(accessTokenResponse); + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + this.authorizedClient = authorizedClient; + } + + /** + * Returns the {@link OAuth2AccessTokenResponse} that triggered the event. + * @return the access token response + */ + public OAuth2AccessTokenResponse getAccessTokenResponse() { + return (OAuth2AccessTokenResponse) this.getSource(); + } + + /** + * Returns the refreshed {@link OAuth2AuthorizedClient}. + * @return the authorized client + */ + public OAuth2AuthorizedClient getAuthorizedClient() { + return this.authorizedClient; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2TokenRefreshedEvent.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2TokenRefreshedEvent.java deleted file mode 100644 index f92091d4cd7..00000000000 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2TokenRefreshedEvent.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright 2002-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.client.event; - -import org.springframework.context.ApplicationEvent; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; - -/** - * An event that is published when an OAuth2 access token is refreshed. - */ -public class OAuth2TokenRefreshedEvent extends ApplicationEvent { - - private final OAuth2AuthorizedClient authorizedClient; - - private final OAuth2AccessTokenResponse accessTokenResponse; - - public OAuth2TokenRefreshedEvent(Object source, OAuth2AuthorizedClient authorizedClient, - OAuth2AccessTokenResponse accessTokenResponse) { - super(source); - this.authorizedClient = authorizedClient; - this.accessTokenResponse = accessTokenResponse; - } - - public OAuth2AuthorizedClient getAuthorizedClient() { - return this.authorizedClient; - } - - public OAuth2AccessTokenResponse getAccessTokenResponse() { - return this.accessTokenResponse; - } - -} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizedClientRefreshedEventListener.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizedClientRefreshedEventListener.java new file mode 100644 index 00000000000..b6f8c45b499 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizedClientRefreshedEventListener.java @@ -0,0 +1,219 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.authentication; + +import java.util.Collection; +import java.util.Map; + +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEventPublisherAware; +import org.springframework.context.ApplicationListener; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; +import org.springframework.security.oauth2.client.event.OAuth2AuthorizedClientRefreshedEvent; +import org.springframework.security.oauth2.client.oidc.authentication.event.OidcUserRefreshedEvent; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtDecoderFactory; +import org.springframework.security.oauth2.jwt.JwtException; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * An {@link ApplicationListener} that listens for events of type + * {@link OAuth2AuthorizedClientRefreshedEvent} and publishes an event of type + * {@link OidcUserRefreshedEvent} in order to refresh an {@link OidcUser}. + * + * @author Steve Riesenberg + * @since 6.5 + * @see org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider + * @see OAuth2AuthorizedClientRefreshedEvent + * @see OidcUserRefreshedEvent + */ +public final class OidcAuthorizedClientRefreshedEventListener + implements ApplicationEventPublisherAware, ApplicationListener { + + private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token"; + + private static final String INVALID_NONCE_ERROR_CODE = "invalid_nonce"; + + private OAuth2UserService userService = new OidcUserService(); + + private JwtDecoderFactory jwtDecoderFactory = new OidcIdTokenDecoderFactory(); + + private GrantedAuthoritiesMapper authoritiesMapper = (authorities) -> authorities; + + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + + private ApplicationEventPublisher applicationEventPublisher; + + @Override + public void onApplicationEvent(OAuth2AuthorizedClientRefreshedEvent event) { + if (this.applicationEventPublisher == null) { + return; + } + + // The response must contain the openid scope + OAuth2AccessTokenResponse accessTokenResponse = event.getAccessTokenResponse(); + if (!accessTokenResponse.getAccessToken().getScopes().contains(OidcScopes.OPENID)) { + return; + } + + // The response must contain an id_token + Map additionalParameters = accessTokenResponse.getAdditionalParameters(); + if (!StringUtils.hasText((String) additionalParameters.get(OidcParameterNames.ID_TOKEN))) { + return; + } + + // The current authentication must be an OAuth2AuthenticationToken + Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication(); + if (!(authentication instanceof OAuth2AuthenticationToken authenticationToken) + || authenticationToken.getClass() != OAuth2AuthenticationToken.class) { + // This event listener only handles the default authentication result. If the + // application customizes the authentication result, then a custom event + // handler should be provided. + return; + } + + // The current principal must be an OidcUser + if (!(authenticationToken.getPrincipal() instanceof OidcUser existingOidcUser)) { + return; + } + + // The registrationId must match the one used to log in + ClientRegistration clientRegistration = event.getAuthorizedClient().getClientRegistration(); + if (!authenticationToken.getAuthorizedClientRegistrationId().equals(clientRegistration.getRegistrationId())) { + return; + } + + // Refresh the OidcUser and send a user refreshed event + OidcIdToken idToken = createOidcToken(clientRegistration, accessTokenResponse); + validateNonce(existingOidcUser, idToken); + OidcUserRequest userRequest = new OidcUserRequest(clientRegistration, accessTokenResponse.getAccessToken(), + idToken, additionalParameters); + OidcUser oidcUser = this.userService.loadUser(userRequest); + Collection mappedAuthorities = this.authoritiesMapper + .mapAuthorities(oidcUser.getAuthorities()); + OAuth2AuthenticationToken authenticationResult = new OAuth2AuthenticationToken(oidcUser, mappedAuthorities, + clientRegistration.getRegistrationId()); + authenticationResult.setDetails(authenticationToken.getDetails()); + OidcUserRefreshedEvent oidcUserRefreshedEvent = new OidcUserRefreshedEvent(accessTokenResponse, + existingOidcUser, oidcUser, authenticationResult); + this.applicationEventPublisher.publishEvent(oidcUserRefreshedEvent); + } + + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + + /** + * Sets the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature + * verification. The factory returns a {@link JwtDecoder} associated to the provided + * {@link ClientRegistration}. + * @param jwtDecoderFactory the {@link JwtDecoderFactory} used for {@link OidcIdToken} + * signature verification + */ + public void setJwtDecoderFactory(JwtDecoderFactory jwtDecoderFactory) { + Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null"); + this.jwtDecoderFactory = jwtDecoderFactory; + } + + /** + * Sets the {@link OAuth2UserService} used for obtaining the user attributes of the + * End-User from the UserInfo Endpoint. + * @param userService the service used for obtaining the user attributes of the + * End-User from the UserInfo Endpoint + */ + public void setUserService(OAuth2UserService userService) { + Assert.notNull(userService, "userService cannot be null"); + this.userService = userService; + } + + /** + * Sets the {@link GrantedAuthoritiesMapper} used for mapping + * {@link OidcUser#getAuthorities()}} to a new set of authorities which will be + * associated to the {@link OAuth2LoginAuthenticationToken}. + * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the + * user's authorities + */ + public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { + Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null"); + this.authoritiesMapper = authoritiesMapper; + } + + /** + * Sets the {@link ApplicationEventPublisher} to be used. + * @param applicationEventPublisher event publisher to be used + */ + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { + Assert.notNull(applicationEventPublisher, "applicationEventPublisher cannot be null"); + this.applicationEventPublisher = applicationEventPublisher; + } + + private OidcIdToken createOidcToken(ClientRegistration clientRegistration, + OAuth2AccessTokenResponse accessTokenResponse) { + JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration); + Jwt jwt = getJwt(accessTokenResponse, jwtDecoder); + return new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims()); + } + + private Jwt getJwt(OAuth2AccessTokenResponse accessTokenResponse, JwtDecoder jwtDecoder) { + try { + Map parameters = accessTokenResponse.getAdditionalParameters(); + return jwtDecoder.decode((String) parameters.get(OidcParameterNames.ID_TOKEN)); + } + catch (JwtException ex) { + OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), null); + throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex); + } + } + + private void validateNonce(OidcUser existingOidcUser, OidcIdToken idToken) { + if (!StringUtils.hasText(idToken.getNonce())) { + return; + } + + if (!idToken.getNonce().equals(existingOidcUser.getNonce())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandler.java deleted file mode 100644 index d1af1a4f485..00000000000 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandler.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Copyright 2002-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.client.oidc.authentication; - -import java.util.Map; - -import org.springframework.context.ApplicationListener; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.context.SecurityContext; -import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.core.context.SecurityContextHolderStrategy; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; -import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.oidc.OidcIdToken; -import org.springframework.security.oauth2.core.oidc.OidcScopes; -import org.springframework.security.oauth2.core.oidc.StandardClaimNames; -import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; -import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; -import org.springframework.security.oauth2.core.oidc.user.OidcUser; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtDecoder; -import org.springframework.security.oauth2.jwt.JwtDecoderFactory; -import org.springframework.security.oauth2.jwt.JwtException; -import org.springframework.util.Assert; - -/** - * An {@link ApplicationListener} that listens for {@link OAuth2TokenRefreshedEvent}s - */ -public class RefreshOidcIdTokenHandler implements ApplicationListener { - - private static final String MISSING_ID_TOKEN_ERROR_CODE = "missing_id_token"; - - private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token"; - - private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder - .getContextHolderStrategy(); - - private JwtDecoderFactory jwtDecoderFactory = new OidcIdTokenDecoderFactory(); - - @Override - public void onApplicationEvent(OAuth2TokenRefreshedEvent event) { - OAuth2AuthorizedClient authorizedClient = event.getAuthorizedClient(); - - if (!authorizedClient.getClientRegistration().getScopes().contains(OidcScopes.OPENID)) { - return; - } - - Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication(); - if (!(authentication instanceof OAuth2AuthenticationToken oauth2Authentication)) { - return; - } - if (!(authentication.getPrincipal() instanceof DefaultOidcUser defaultOidcUser)) { - return; - } - - OAuth2AccessTokenResponse accessTokenResponse = event.getAccessTokenResponse(); - - String idToken = (String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN); - if (idToken == null || idToken.isBlank()) { - OAuth2Error missingIdTokenError = new OAuth2Error(MISSING_ID_TOKEN_ERROR_CODE, - "ID token is missing in the token response", null); - throw new OAuth2AuthenticationException(missingIdTokenError, missingIdTokenError.toString()); - } - - ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); - OidcIdToken refreshedOidcToken = createOidcToken(clientRegistration, accessTokenResponse); - updateSecurityContext(oauth2Authentication, defaultOidcUser, refreshedOidcToken); - } - - /** - * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use - * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. - */ - public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { - this.securityContextHolderStrategy = securityContextHolderStrategy; - } - - /** - * Sets the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature - * verification. The factory returns a {@link JwtDecoder} associated to the provided - * {@link ClientRegistration}. - * @param jwtDecoderFactory the {@link JwtDecoderFactory} used for {@link OidcIdToken} - * signature verification - */ - public final void setJwtDecoderFactory(JwtDecoderFactory jwtDecoderFactory) { - Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null"); - this.jwtDecoderFactory = jwtDecoderFactory; - } - - private void updateSecurityContext(OAuth2AuthenticationToken oauth2Authentication, DefaultOidcUser defaultOidcUser, - OidcIdToken refreshedOidcToken) { - OidcUser oidcUser = new DefaultOidcUser(defaultOidcUser.getAuthorities(), refreshedOidcToken, - defaultOidcUser.getUserInfo(), StandardClaimNames.SUB); - - SecurityContext context = this.securityContextHolderStrategy.createEmptyContext(); - context.setAuthentication(new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(), - oauth2Authentication.getAuthorizedClientRegistrationId())); - - this.securityContextHolderStrategy.setContext(context); - } - - private OidcIdToken createOidcToken(ClientRegistration clientRegistration, - OAuth2AccessTokenResponse accessTokenResponse) { - JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration); - Jwt jwt = getJwt(accessTokenResponse, jwtDecoder); - return new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims()); - } - - private Jwt getJwt(OAuth2AccessTokenResponse accessTokenResponse, JwtDecoder jwtDecoder) { - try { - Map parameters = accessTokenResponse.getAdditionalParameters(); - return jwtDecoder.decode((String) parameters.get(OidcParameterNames.ID_TOKEN)); - } - catch (JwtException ex) { - OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), null); - throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex); - } - } - -} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/event/OidcUserRefreshedEvent.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/event/OidcUserRefreshedEvent.java new file mode 100644 index 00000000000..eda81717396 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/event/OidcUserRefreshedEvent.java @@ -0,0 +1,98 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.authentication.event; + +import java.io.Serial; + +import org.springframework.context.ApplicationEvent; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.util.Assert; + +/** + * An event that is published when an {@link OidcUser} is refreshed as a result of using a + * {@code refresh_token} to obtain an OAuth 2.0 Access Token Response that contains an + * {@code id_token}. + * + * @author Steve Riesenberg + * @since 6.5 + * @see org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider + */ +public final class OidcUserRefreshedEvent extends ApplicationEvent { + + @Serial + private static final long serialVersionUID = 2657442604286019694L; + + private final OidcUser oldOidcUser; + + private final OidcUser newOidcUser; + + private final Authentication authentication; + + /** + * Creates a new instance with the provided parameters. + * @param accessTokenResponse the {@link OAuth2AccessTokenResponse} that triggered the + * event + * @param oldOidcUser the original {@link OidcUser} + * @param newOidcUser the refreshed {@link OidcUser} + * @param authentication the authentication result + */ + public OidcUserRefreshedEvent(OAuth2AccessTokenResponse accessTokenResponse, OidcUser oldOidcUser, + OidcUser newOidcUser, Authentication authentication) { + super(accessTokenResponse); + Assert.notNull(oldOidcUser, "oldOidcUser cannot be null"); + Assert.notNull(newOidcUser, "newOidcUser cannot be null"); + Assert.notNull(authentication, "authentication cannot be null"); + this.oldOidcUser = oldOidcUser; + this.newOidcUser = newOidcUser; + this.authentication = authentication; + } + + /** + * Returns the {@link OAuth2AccessTokenResponse} that triggered the event. + * @return the access token response + */ + public OAuth2AccessTokenResponse getAccessTokenResponse() { + return (OAuth2AccessTokenResponse) this.getSource(); + } + + /** + * Returns the original {@link OidcUser}. + * @return the original user + */ + public OidcUser getOldOidcUser() { + return this.oldOidcUser; + } + + /** + * Returns the refreshed {@link OidcUser}. + * @return the refreshed user + */ + public OidcUser getNewOidcUser() { + return this.newOidcUser; + } + + /** + * Returns the authentication result. + * @return the authentication result + */ + public Authentication getAuthentication() { + return this.authentication; + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java index dc4c4232004..5b7fc551f36 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; -import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent; +import org.springframework.security.oauth2.client.event.OAuth2AuthorizedClientRefreshedEvent; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -297,7 +297,7 @@ private static class OAuth2TokenRefreshedAwareEventPublisher implements Applicat @Override public void publishEvent(Object event) { - if (OAuth2TokenRefreshedEvent.class.isAssignableFrom(event.getClass())) { + if (OAuth2AuthorizedClientRefreshedEvent.class.isAssignableFrom(event.getClass())) { this.flag = true; } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizedClientRefreshedEventListenerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizedClientRefreshedEventListenerTests.java new file mode 100644 index 00000000000..64a94a8f37b --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizedClientRefreshedEventListenerTests.java @@ -0,0 +1,420 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.authentication; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.event.OAuth2AuthorizedClientRefreshedEvent; +import org.springframework.security.oauth2.client.oidc.authentication.event.OidcUserRefreshedEvent; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.core.user.DefaultOAuth2User; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtException; +import org.springframework.security.oauth2.jwt.TestJwts; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Tests for {@link OidcAuthorizedClientRefreshedEventListener}. + * + * @author Steve Riesenberg + */ +public class OidcAuthorizedClientRefreshedEventListenerTests { + + private static final String SUBJECT = "surfer-dude"; + + private static final String ACCESS_TOKEN_VALUE = "hang-ten"; + + private static final String REFRESH_TOKEN_VALUE = "surfs-up"; + + private static final String ID_TOKEN_VALUE = "beach-break"; + + private OidcAuthorizedClientRefreshedEventListener eventListener; + + private SecurityContextHolderStrategy securityContextHolderStrategy; + + private JwtDecoder jwtDecoder; + + private OidcUserService userService; + + private ApplicationEventPublisher applicationEventPublisher; + + private ClientRegistration clientRegistration; + + private OAuth2AuthorizedClient authorizedClient; + + private OAuth2AccessTokenResponse accessTokenResponse; + + private Jwt jwt; + + private OidcUser oidcUser; + + @BeforeEach + public void setUp() { + this.jwtDecoder = mock(JwtDecoder.class); + this.userService = mock(OidcUserService.class); + this.securityContextHolderStrategy = mock(SecurityContextHolderStrategy.class); + this.applicationEventPublisher = mock(ApplicationEventPublisher.class); + + this.eventListener = new OidcAuthorizedClientRefreshedEventListener(); + this.eventListener.setUserService(this.userService); + this.eventListener.setJwtDecoderFactory((clientRegistration) -> this.jwtDecoder); + this.eventListener.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); + this.eventListener.setApplicationEventPublisher(this.applicationEventPublisher); + + this.clientRegistration = TestClientRegistrations.clientRegistration().scope(OidcScopes.OPENID).build(); + this.authorizedClient = createAuthorizedClient(this.clientRegistration); + this.accessTokenResponse = createAccessTokenResponse(OidcScopes.OPENID); + this.jwt = createJwt(); + this.oidcUser = createOidcUser(); + } + + @Test + public void setSecurityContextHolderStrategyWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.eventListener.setSecurityContextHolderStrategy(null)) + .withMessage("securityContextHolderStrategy cannot be null"); + } + + @Test + public void setJwtDecoderFactoryWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.eventListener.setJwtDecoderFactory(null)) + .withMessage("jwtDecoderFactory cannot be null"); + } + + @Test + public void setUserServiceWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.eventListener.setUserService(null)) + .withMessage("userService cannot be null"); + } + + @Test + public void setAuthoritiesMapperWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.eventListener.setAuthoritiesMapper(null)) + .withMessage("authoritiesMapper cannot be null"); + } + + @Test + public void setApplicationEventPublisherWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.eventListener.setApplicationEventPublisher(null)) + .withMessage("applicationEventPublisher cannot be null"); + } + + @Test + public void onApplicationEventWhenAccessTokenResponseMissingIdTokenThenOidcUserRefreshedEventNotPublished() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .scopes(Set.of(OidcScopes.OPENID)) + .build(); + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + verifyNoInteractions(this.securityContextHolderStrategy, this.jwtDecoder, this.userService, + this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenAccessTokenResponseMissingOpenidScopeThenOidcUserRefreshedEventNotPublished() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.oidcAccessTokenResponse() + .scopes(Set.of()) + .build(); + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + verifyNoInteractions(this.securityContextHolderStrategy, this.jwtDecoder, this.userService, + this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenAuthenticationIsNotOAuth2ThenOidcUserRefreshedEventNotPublished() { + TestingAuthenticationToken authentication = new TestingAuthenticationToken(SUBJECT, null); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + + verify(this.securityContextHolderStrategy).getContext(); + verifyNoMoreInteractions(this.securityContextHolderStrategy); + verifyNoInteractions(this.jwtDecoder, this.userService, this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenAuthenticationIsCustomThenOidcUserRefreshedEventNotPublished() { + OAuth2AuthenticationToken authentication = new CustomOAuth2AuthenticationToken(this.oidcUser, + this.oidcUser.getAuthorities(), this.clientRegistration.getRegistrationId()); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + + verify(this.securityContextHolderStrategy).getContext(); + verifyNoMoreInteractions(this.securityContextHolderStrategy); + verifyNoInteractions(this.jwtDecoder, this.userService, this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenPrincipalIsOAuth2UserThenOidcUserRefreshedEventNotPublished() { + Map attributes = Map.of(StandardClaimNames.SUB, SUBJECT); + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("OAUTH2_USER"), attributes, + StandardClaimNames.SUB); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(oauth2User, + oauth2User.getAuthorities(), this.clientRegistration.getRegistrationId()); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + + verify(this.securityContextHolderStrategy).getContext(); + verifyNoMoreInteractions(this.securityContextHolderStrategy); + verifyNoInteractions(this.jwtDecoder, this.userService, this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenClientRegistrationIdDoesNotMatchThenOidcUserRefreshedEventNotPublished() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .registrationId("test") + .build(); + OAuth2AuthenticationToken authentication = createAuthenticationToken(clientRegistration); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + + verify(this.securityContextHolderStrategy).getContext(); + verifyNoMoreInteractions(this.securityContextHolderStrategy); + verifyNoInteractions(this.jwtDecoder, this.userService, this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenAccessTokenResponseIncludesIdTokenThenPublishOidcUserRefreshedEvent() { + OAuth2AuthenticationToken authentication = createAuthenticationToken(this.clientRegistration); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + given(this.jwtDecoder.decode(anyString())).willReturn(this.jwt); + given(this.userService.loadUser(any(OidcUserRequest.class))).willReturn(this.oidcUser); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + + ArgumentCaptor userRequestCaptor = ArgumentCaptor.forClass(OidcUserRequest.class); + ArgumentCaptor userRefreshedEventCaptor = ArgumentCaptor + .forClass(OidcUserRefreshedEvent.class); + verify(this.securityContextHolderStrategy).getContext(); + verify(this.jwtDecoder).decode(this.jwt.getTokenValue()); + verify(this.userService).loadUser(userRequestCaptor.capture()); + verify(this.applicationEventPublisher).publishEvent(userRefreshedEventCaptor.capture()); + verifyNoMoreInteractions(this.securityContextHolderStrategy, this.jwtDecoder, this.userService, + this.applicationEventPublisher); + + OidcUserRequest userRequest = userRequestCaptor.getValue(); + assertThat(userRequest.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(userRequest.getAccessToken()).isSameAs(this.accessTokenResponse.getAccessToken()); + assertThat(userRequest.getIdToken().getTokenValue()).isEqualTo(this.jwt.getTokenValue()); + + OidcUserRefreshedEvent userRefreshedEvent = userRefreshedEventCaptor.getValue(); + assertThat(userRefreshedEvent.getAccessTokenResponse()).isSameAs(this.accessTokenResponse); + assertThat(userRefreshedEvent.getOldOidcUser()).isSameAs(authentication.getPrincipal()); + assertThat(userRefreshedEvent.getNewOidcUser()).isSameAs(this.oidcUser); + assertThat(userRefreshedEvent.getAuthentication()).isNotSameAs(authentication); + assertThat(userRefreshedEvent.getAuthentication()).isInstanceOf(OAuth2AuthenticationToken.class); + + OAuth2AuthenticationToken authenticationResult = (OAuth2AuthenticationToken) userRefreshedEvent + .getAuthentication(); + assertThat(authenticationResult.getPrincipal()).isEqualTo(this.oidcUser); + assertThat(authenticationResult.getAuthorities()).containsExactlyElementsOf(this.oidcUser.getAuthorities()); + assertThat(authenticationResult.getAuthorizedClientRegistrationId()) + .isEqualTo(this.clientRegistration.getRegistrationId()); + } + + @Test + public void onApplicationEventWhenIdTokenNonceDoesNotMatchThenThrowsOAuth2AuthenticationException() { + Jwt jwt = TestJwts.jwt().claim(IdTokenClaimNames.NONCE, "invalid").build(); + OAuth2AuthenticationToken authentication = createAuthenticationToken(this.clientRegistration); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + given(this.jwtDecoder.decode(anyString())).willReturn(jwt); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent)) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo("invalid_nonce"); + verify(this.securityContextHolderStrategy).getContext(); + verify(this.jwtDecoder).decode(this.jwt.getTokenValue()); + verifyNoMoreInteractions(this.securityContextHolderStrategy, this.jwtDecoder); + verifyNoInteractions(this.userService, this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenInvalidIdTokenThenThrowsOAuth2AuthenticationException() { + OAuth2AuthenticationToken authentication = createAuthenticationToken(this.clientRegistration); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + given(this.jwtDecoder.decode(anyString())).willThrow(new JwtException("Invalid token")); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent)) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo("invalid_id_token"); + verify(this.securityContextHolderStrategy).getContext(); + verify(this.jwtDecoder).decode(this.jwt.getTokenValue()); + verifyNoMoreInteractions(this.securityContextHolderStrategy, this.jwtDecoder); + verifyNoInteractions(this.userService, this.applicationEventPublisher); + } + + @Test + public void onApplicationEventWhenCustomAuthoritiesMapperSetThenUsed() { + OAuth2AuthenticationToken authentication = createAuthenticationToken(this.clientRegistration); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + given(this.jwtDecoder.decode(anyString())).willReturn(this.jwt); + given(this.userService.loadUser(any(OidcUserRequest.class))).willReturn(this.oidcUser); + + GrantedAuthoritiesMapper grantedAuthoritiesMapper = mock(GrantedAuthoritiesMapper.class); + this.eventListener.setAuthoritiesMapper(grantedAuthoritiesMapper); + + OAuth2AuthorizedClientRefreshedEvent authorizedClientRefreshedEvent = new OAuth2AuthorizedClientRefreshedEvent( + this.accessTokenResponse, this.authorizedClient); + this.eventListener.onApplicationEvent(authorizedClientRefreshedEvent); + + verify(grantedAuthoritiesMapper).mapAuthorities(this.oidcUser.getAuthorities()); + verifyNoMoreInteractions(grantedAuthoritiesMapper); + } + + private static OAuth2AuthorizedClient createAuthorizedClient(ClientRegistration clientRegistration) { + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(30, ChronoUnit.SECONDS); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, ACCESS_TOKEN_VALUE, + issuedAt, expiresAt, clientRegistration.getScopes()); + + return new OAuth2AuthorizedClient(clientRegistration, SUBJECT, accessToken); + } + + private static OAuth2AccessTokenResponse createAccessTokenResponse(String... scope) { + Set scopes = Set.of(scope); + Map additionalParameters = new HashMap<>(); + if (scopes.contains(OidcScopes.OPENID)) { + additionalParameters.put(OidcParameterNames.ID_TOKEN, ID_TOKEN_VALUE); + } + + return OAuth2AccessTokenResponse.withToken(ACCESS_TOKEN_VALUE) + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .scopes(scopes) + .refreshToken(REFRESH_TOKEN_VALUE) + .expiresIn(60L) + .additionalParameters(additionalParameters) + .build(); + } + + private static Jwt createJwt() { + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.MINUTES); + return TestJwts.jwt() + .subject(SUBJECT) + .tokenValue(ID_TOKEN_VALUE) + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .claim(OidcParameterNames.NONCE, "nonce") + .build(); + } + + private static OidcUser createOidcUser() { + Map claims = new HashMap<>(); + claims.put(IdTokenClaimNames.SUB, SUBJECT); + claims.put(IdTokenClaimNames.ISS, "issuer"); + claims.put(IdTokenClaimNames.AUD, List.of("audience1", "audience2")); + claims.put(IdTokenClaimNames.NONCE, "nonce"); + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.MINUTES); + OidcIdToken idToken = new OidcIdToken(ID_TOKEN_VALUE, issuedAt, expiresAt, claims); + + return new DefaultOidcUser(AuthorityUtils.createAuthorityList("OIDC_USER"), idToken); + } + + private static OAuth2AuthenticationToken createAuthenticationToken(ClientRegistration clientRegistration) { + OidcUser oidcUser = createOidcUser(); + return new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(), + clientRegistration.getRegistrationId()); + } + + private static final class CustomOAuth2AuthenticationToken extends OAuth2AuthenticationToken { + + CustomOAuth2AuthenticationToken(OAuth2User principal, Collection authorities, + String authorizedClientRegistrationId) { + super(principal, authorities, authorizedClientRegistrationId); + } + + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandlerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandlerTests.java deleted file mode 100644 index 61ff5aae892..00000000000 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandlerTests.java +++ /dev/null @@ -1,284 +0,0 @@ -/* - * Copyright 2002-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.client.oidc.authentication; - -import java.time.Instant; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; - -import org.springframework.security.authentication.TestingAuthenticationToken; -import org.springframework.security.core.authority.AuthorityUtils; -import org.springframework.security.core.context.SecurityContext; -import org.springframework.security.core.context.SecurityContextHolderStrategy; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; -import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.oidc.OidcIdToken; -import org.springframework.security.oauth2.core.oidc.OidcScopes; -import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; -import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; -import org.springframework.security.oauth2.core.user.DefaultOAuth2User; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtDecoder; -import org.springframework.security.oauth2.jwt.JwtDecoderFactory; -import org.springframework.security.oauth2.jwt.JwtException; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.BDDMockito.given; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; - -class RefreshOidcIdTokenHandlerTests { - - private static final String EXISTING_ID_TOKEN_VALUE = "id-token-value"; - - private static final String REFRESHED_ID_TOKEN_VALUE = "new-id-token-value"; - - private static final String EXISTING_ACCESS_TOKEN_VALUE = "token-value"; - - private static final String REFRESHED_ACCESS_TOKEN_VALUE = "new-token-value"; - - private RefreshOidcIdTokenHandler handler; - - private RefreshTokenOAuth2AuthorizedClientProvider provider; - - private ClientRegistration clientRegistration; - - private OAuth2AuthorizedClient authorizedClient; - - private JwtDecoder jwtDecoder; - - private SecurityContext securityContext; - - private OidcIdToken existingIdToken; - - @BeforeEach - void setUp() { - this.handler = new RefreshOidcIdTokenHandler(); - - this.clientRegistration = createClientRegistrationWithScopes(OidcScopes.OPENID); - this.authorizedClient = createAuthorizedClient(this.clientRegistration); - - this.provider = mock(RefreshTokenOAuth2AuthorizedClientProvider.class); - - JwtDecoderFactory jwtDecoderFactory = mock(JwtDecoderFactory.class); - this.jwtDecoder = mock(JwtDecoder.class); - SecurityContextHolderStrategy securityContextHolderStrategy = mock(SecurityContextHolderStrategy.class); - this.securityContext = mock(SecurityContext.class); - - this.handler.setJwtDecoderFactory(jwtDecoderFactory); - this.handler.setSecurityContextHolderStrategy(securityContextHolderStrategy); - - given(jwtDecoderFactory.createDecoder(any())).willReturn(this.jwtDecoder); - given(securityContextHolderStrategy.createEmptyContext()).willReturn(this.securityContext); - given(securityContextHolderStrategy.getContext()).willReturn(this.securityContext); - - Map claims = new HashMap<>(); - claims.put("sub", "subject"); - Jwt existingIdTokenJwt = new Jwt(EXISTING_ID_TOKEN_VALUE, Instant.now(), Instant.now().plusSeconds(3600), - Map.of("alg", "RS256"), claims); - Jwt refreshedIdTokenJwt = new Jwt(REFRESHED_ID_TOKEN_VALUE, Instant.now(), Instant.now().plusSeconds(3600), - Map.of("alg", "RS256"), claims); - - this.existingIdToken = new OidcIdToken(existingIdTokenJwt.getTokenValue(), existingIdTokenJwt.getIssuedAt(), - existingIdTokenJwt.getExpiresAt(), existingIdTokenJwt.getClaims()); - - given(this.jwtDecoder.decode(existingIdTokenJwt.getTokenValue())).willReturn(existingIdTokenJwt); - given(this.jwtDecoder.decode(refreshedIdTokenJwt.getTokenValue())).willReturn(refreshedIdTokenJwt); - } - - @Test - void handleEventWhenValidIdTokenThenUpdatesSecurityContext() { - - DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), - this.existingIdToken); - OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser, - existingUser.getAuthorities(), "registration-id"); - given(this.securityContext.getAuthentication()).willReturn(existingAuth); - - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse - .withToken(REFRESHED_ACCESS_TOKEN_VALUE) - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(3600) - .additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE)) - .build(); - - OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient, - accessTokenResponse); - this.handler.onApplicationEvent(event); - - ArgumentCaptor authenticationCaptor = ArgumentCaptor - .forClass(OAuth2AuthenticationToken.class); - verify(this.securityContext).setAuthentication(authenticationCaptor.capture()); - - OAuth2AuthenticationToken newAuthentication = authenticationCaptor.getValue(); - assertThat(newAuthentication.getPrincipal()).isInstanceOf(DefaultOidcUser.class); - DefaultOidcUser newUser = (DefaultOidcUser) newAuthentication.getPrincipal(); - assertThat(newUser.getIdToken().getTokenValue()).isEqualTo(REFRESHED_ID_TOKEN_VALUE); - } - - @Test - void handleEventWhenAuthorizedClientIsNotOidcThenDoesNothing() { - - this.clientRegistration = createClientRegistrationWithScopes("read"); - this.authorizedClient = createAuthorizedClient(this.clientRegistration); - - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse - .withToken(REFRESHED_ACCESS_TOKEN_VALUE) - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(3600) - .additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE)) - .build(); - - OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient, - accessTokenResponse); - - this.handler.onApplicationEvent(event); - - verify(this.securityContext, never()).setAuthentication(any()); - verify(this.jwtDecoder, never()).decode(any()); - } - - @Test - void handleEventWhenAuthenticationNotOAuth2AuthenticationTokenThenDoesNothing() { - - given(this.securityContext.getAuthentication()).willReturn(mock(TestingAuthenticationToken.class)); - - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse - .withToken(REFRESHED_ACCESS_TOKEN_VALUE) - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(3600) - .additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE)) - .build(); - - OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient, - accessTokenResponse); - - this.handler.onApplicationEvent(event); - - verify(this.securityContext, never()).setAuthentication(any()); - } - - @Test - void handleEventWhenNotOidcUserThenDoesNothing() { - - OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken( - new DefaultOAuth2User(Collections.emptySet(), - Collections.singletonMap("custom-attribute", "test-subject"), "custom-attribute"), - AuthorityUtils.createAuthorityList("ROLE_USER"), "registration-id"); - given(this.securityContext.getAuthentication()).willReturn(existingAuth); - - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse - .withToken(REFRESHED_ACCESS_TOKEN_VALUE) - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(3600) - .additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE)) - .build(); - - OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient, - accessTokenResponse); - - this.handler.onApplicationEvent(event); - - verify(this.securityContext, never()).setAuthentication(any()); - } - - @Test - void handleEventWhenMissingIdTokenThenThrowsException() { - - DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), - this.existingIdToken); - OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser, - existingUser.getAuthorities(), "registration-id"); - given(this.securityContext.getAuthentication()).willReturn(existingAuth); - - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse - .withToken(REFRESHED_ACCESS_TOKEN_VALUE) - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(3600) - .additionalParameters(new HashMap<>()) // missing ID token - .build(); - - OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient, - accessTokenResponse); - - assertThatExceptionOfType(OAuth2AuthenticationException.class) - .isThrownBy(() -> this.handler.onApplicationEvent(event)) - .withMessageContaining("missing_id_token"); - } - - @Test - void handleEventWhenInvalidIdTokenThenThrowsException() { - - DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), - this.existingIdToken); - OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser, - existingUser.getAuthorities(), "registration-id"); - given(this.securityContext.getAuthentication()).willReturn(existingAuth); - - given(this.jwtDecoder.decode(any())).willThrow(new JwtException("Invalid token")); - - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse - .withToken(REFRESHED_ACCESS_TOKEN_VALUE) - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(3600) - .additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, "invalid-id-token")) - .build(); - - OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient, - accessTokenResponse); - - assertThatExceptionOfType(OAuth2AuthenticationException.class) - .isThrownBy(() -> this.handler.onApplicationEvent(event)) - .withMessageContaining("invalid_id_token"); - } - - private ClientRegistration createClientRegistrationWithScopes(String... scope) { - return ClientRegistration.withRegistrationId("registration-id") - .clientId("client-id") - .clientSecret("secret") - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUri("http://localhost") - .scope(scope) - .authorizationUri("https://provider.com/oauth2/authorize") - .tokenUri("https://provider.com/oauth2/token") - .jwkSetUri("https://provider.com/jwk") - .userInfoUri("https://provider.com/user") - .build(); - } - - private static OAuth2AuthorizedClient createAuthorizedClient(ClientRegistration clientRegistration) { - return new OAuth2AuthorizedClient(clientRegistration, "principal-name", - new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, EXISTING_ACCESS_TOKEN_VALUE, Instant.now(), - Instant.now().plusSeconds(3600))); - } - -}