diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationContext.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationContext.java new file mode 100644 index 000000000..d6ac24dd9 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationContext.java @@ -0,0 +1,185 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.Assert; + +/** + * An {@link OAuth2AuthenticationContext} that holds an + * {@link OAuth2DeviceVerificationAuthenticationToken} and additional information and is + * used when validating the OAuth 2.0 Device Verification Request parameters, as well as + * determining if authorization consent is required. + * + * @author Dinesh Gupta + * @since 2.0.0 + * @see OAuth2AuthenticationContext + * @see OAuth2DeviceVerificationAuthenticationToken + * @see OAuth2DeviceVerificationAuthenticationProvider#setAuthorizationConsentRequired(java.util.function.Predicate) + */ +public final class OAuth2DeviceVerificationAuthenticationContext implements OAuth2AuthenticationContext { + + private final Map context; + + private OAuth2DeviceVerificationAuthenticationContext(Map context) { + this.context = Collections.unmodifiableMap(new HashMap<>(context)); + } + + @SuppressWarnings("unchecked") + @Nullable + @Override + public T getAuthentication() { + return (T) get(OAuth2DeviceVerificationAuthenticationToken.class); + } + + @Override + public boolean hasKey(Object key) { + Assert.notNull(key, "key cannot be null"); + return this.context.containsKey(key); + } + + @SuppressWarnings("unchecked") + @Nullable + @Override + public V get(Object key) { + return hasKey(key) ? (V) this.context.get(key) : null; + } + + /** + * Returns the {@link RegisteredClient registered client}. + * @return the {@link RegisteredClient} + */ + public RegisteredClient getRegisteredClient() { + return get(RegisteredClient.class); + } + + /** + * Returns the {@link OAuth2Authorization authorization}. + * @return the {@link OAuth2Authorization}, or {@code null} if not available + */ + @Nullable + public OAuth2Authorization getAuthorization() { + return get(OAuth2Authorization.class); + } + + /** + * Returns the {@link OAuth2AuthorizationConsent authorization consent}. + * @return the {@link OAuth2AuthorizationConsent}, or {@code null} if not available + */ + @Nullable + public OAuth2AuthorizationConsent getAuthorizationConsent() { + return get(OAuth2AuthorizationConsent.class); + } + + /** + * Returns the requested scopes. Never {@code null}; always a {@link Set} (possibly + * empty). + * @return the requested scopes + */ + @SuppressWarnings("unchecked") + public Set getRequestedScopes() { + Set scopes = get(Set.class); + return scopes != null ? scopes : Collections.emptySet(); + } + + /** + * Constructs a new {@link Builder} with the provided + * {@link OAuth2DeviceVerificationAuthenticationToken}. + * @param authentication the {@link OAuth2DeviceVerificationAuthenticationToken} + * @return the {@link Builder} + */ + public static Builder with(OAuth2DeviceVerificationAuthenticationToken authentication) { + return new Builder(authentication); + } + + /** + * A builder for {@link OAuth2DeviceVerificationAuthenticationContext}. + */ + public static final class Builder { + + private final Map context = new HashMap<>(); + + private Builder(OAuth2DeviceVerificationAuthenticationToken authentication) { + Assert.notNull(authentication, "authentication cannot be null"); + context.put(OAuth2DeviceVerificationAuthenticationToken.class, authentication); + } + + /** + * Sets the {@link RegisteredClient registered client}. + * @param registeredClient the {@link RegisteredClient} + * @return the {@link Builder} for further configuration + */ + public Builder registeredClient(RegisteredClient registeredClient) { + context.put(RegisteredClient.class, registeredClient); + return this; + } + + /** + * Sets the {@link OAuth2Authorization authorization}. + * @param authorization the {@link OAuth2Authorization} + * @return the {@link Builder} for further configuration + */ + public Builder authorization(@Nullable OAuth2Authorization authorization) { + if (authorization != null) { + context.put(OAuth2Authorization.class, authorization); + } + return this; + } + + /** + * Sets the {@link OAuth2AuthorizationConsent authorization consent}. + * @param authorizationConsent the {@link OAuth2AuthorizationConsent} + * @return the {@link Builder} for further configuration + */ + public Builder authorizationConsent(@Nullable OAuth2AuthorizationConsent authorizationConsent) { + if (authorizationConsent != null) { + context.put(OAuth2AuthorizationConsent.class, authorizationConsent); + } + return this; + } + + /** + * Sets the requested scopes. Never {@code null}; always a {@link Set} (possibly + * empty). + * @param requestedScopes the requested scopes + * @return the {@link Builder} for further configuration + */ + public Builder requestedScopes(@Nullable Set requestedScopes) { + context.put(Set.class, requestedScopes != null ? requestedScopes : Collections.emptySet()); + return this; + } + + /** + * Builds a new {@link OAuth2DeviceVerificationAuthenticationContext}. + * @return the {@link OAuth2DeviceVerificationAuthenticationContext} + */ + public OAuth2DeviceVerificationAuthenticationContext build() { + Assert.notNull(context.get(RegisteredClient.class), "registeredClient cannot be null"); + return new OAuth2DeviceVerificationAuthenticationContext(context); + } + + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java index b631c088d..0f676f7ae 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java @@ -18,6 +18,7 @@ import java.security.Principal; import java.util.Base64; import java.util.Set; +import java.util.function.Predicate; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -78,6 +79,8 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut private final OAuth2AuthorizationConsentService authorizationConsentService; + private Predicate authorizationConsentRequired = OAuth2DeviceVerificationAuthenticationProvider::isAuthorizationConsentRequired; + /** * Constructs an {@code OAuth2DeviceVerificationAuthenticationProvider} using the * provided parameters. @@ -140,12 +143,19 @@ public Authentication authenticate(Authentication authentication) throws Authent this.logger.trace("Retrieved registered client"); } + OAuth2DeviceVerificationAuthenticationContext.Builder authenticationContextBuilder = OAuth2DeviceVerificationAuthenticationContext + .with(deviceVerificationAuthentication) + .registeredClient(registeredClient) + .authorization(authorization); + Set requestedScopes = authorization.getAttribute(OAuth2ParameterNames.SCOPE); + authenticationContextBuilder.requestedScopes(requestedScopes); OAuth2AuthorizationConsent currentAuthorizationConsent = this.authorizationConsentService .findById(registeredClient.getId(), principal.getName()); + authenticationContextBuilder.authorizationConsent(currentAuthorizationConsent); - if (requiresAuthorizationConsent(requestedScopes, currentAuthorizationConsent)) { + if (this.authorizationConsentRequired.test(authenticationContextBuilder.build())) { String state = DEFAULT_STATE_GENERATOR.generateKey(); authorization = OAuth2Authorization.from(authorization) .principalName(principal.getName()) @@ -201,13 +211,38 @@ public boolean supports(Class authentication) { return OAuth2DeviceVerificationAuthenticationToken.class.isAssignableFrom(authentication); } - private static boolean requiresAuthorizationConsent(Set requestedScopes, - OAuth2AuthorizationConsent authorizationConsent) { + /** + * Sets the {@code Predicate} used to determine if authorization consent is required + * during the OAuth 2.0 Device Verification flow. + * + *

+ * The {@link OAuth2DeviceVerificationAuthenticationContext} provides the predicate + * access to the following context attributes: + *

    + *
  • The {@link RegisteredClient} associated with the authorization request.
  • + *
  • The {@link OAuth2Authorization} associated with the device verification.
  • + *
  • The {@link OAuth2AuthorizationConsent} previously granted to the + * {@link RegisteredClient}, or {@code null} if not available.
  • + *
+ *

+ * @param authorizationConsentRequired the {@code Predicate} used to determine if + * authorization consent is required for device verification + * @since 2.0.0 + */ + public void setAuthorizationConsentRequired( + Predicate authorizationConsentRequired) { + Assert.notNull(authorizationConsentRequired, "authorizationConsentRequired cannot be null"); + this.authorizationConsentRequired = authorizationConsentRequired; + } + + private static boolean isAuthorizationConsentRequired( + OAuth2DeviceVerificationAuthenticationContext authenticationContext) { - if (authorizationConsent != null && authorizationConsent.getScopes().containsAll(requestedScopes)) { + if (authenticationContext.getAuthorizationConsent() != null && authenticationContext.getAuthorizationConsent() + .getScopes() + .containsAll(authenticationContext.getRequestedScopes())) { return false; } - return true; } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java index fd6a54d60..576acdffa 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Predicate; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -50,10 +51,12 @@ import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; import org.springframework.security.oauth2.server.authorization.context.TestAuthorizationServerContext; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; +import org.springframework.security.oauth2.server.authorization.settings.ClientSettings; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; @@ -124,6 +127,13 @@ public void constructorWhenAuthorizationConsentServiceIsNullThenThrowIllegalArgu // @formatter:on } + @Test + public void setAuthorizationConsentRequiredWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authenticationProvider.setAuthorizationConsentRequired(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationConsentRequired cannot be null"); + } + @Test public void supportsWhenTypeOAuth2DeviceVerificationAuthenticationTokenThenReturnTrue() { assertThat(this.authenticationProvider.supports(OAuth2DeviceVerificationAuthenticationToken.class)).isTrue(); @@ -381,6 +391,81 @@ public void authenticateWhenAuthorizationConsentExistsAndRequestedScopesDoNotMat .isEqualTo(authenticationResult.getState()); } + @Test + void authenticateWhenPredicateTrueThenReturnsConsentToken() { + @SuppressWarnings("unchecked") + Predicate consentPredicate = mock(Predicate.class); + given(consentPredicate.test(any())).willReturn(true); + authenticationProvider.setAuthorizationConsentRequired(consentPredicate); + + RegisteredClient client = TestRegisteredClients.registeredClient().build(); + given(registeredClientRepository.findById(client.getId())).willReturn(client); + + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(client) + .token(createDeviceCode()) + .token(createUserCode()) + .attribute(OAuth2ParameterNames.SCOPE, client.getScopes()) + .build(); + + TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password"); + principal.setAuthenticated(true); + + OAuth2DeviceVerificationAuthenticationToken authRequest = new OAuth2DeviceVerificationAuthenticationToken( + principal, USER_CODE, Collections.emptyMap()); + + given(authorizationService.findByToken(USER_CODE, + OAuth2DeviceVerificationAuthenticationProvider.USER_CODE_TOKEN_TYPE)) + .willReturn(authorization); + given(authorizationConsentService.findById(client.getId(), principal.getName())).willReturn(null); + + Authentication result = authenticationProvider.authenticate(authRequest); + + assertThat(result).isInstanceOf(OAuth2DeviceAuthorizationConsentAuthenticationToken.class); + OAuth2DeviceAuthorizationConsentAuthenticationToken consentToken = (OAuth2DeviceAuthorizationConsentAuthenticationToken) result; + + assertThat(consentToken.isAuthenticated()).isTrue(); + assertThat(consentToken.getClientId()).isEqualTo(client.getClientId()); + assertThat(consentToken.getPrincipal()).isEqualTo(authRequest.getPrincipal()); + assertThat(consentToken.getUserCode()).isEqualTo(authRequest.getUserCode()); + assertThat(consentToken.getRequestedScopes()).containsExactlyInAnyOrderElementsOf(client.getScopes()); + assertThat(consentToken.getState()).isNotNull(); + + verify(consentPredicate).test(any()); + } + + @Test + void authenticateWhenPredicateFalseThenSkipsConsentPage() { + RegisteredClient client = TestRegisteredClients.registeredClient() + .clientSettings(ClientSettings.builder().requireAuthorizationConsent(false).build()) + .build(); + + authenticationProvider.setAuthorizationConsentRequired( + ctx -> ctx.getRegisteredClient().getClientSettings().isRequireAuthorizationConsent()); + + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(client) + .token(createDeviceCode()) + .token(createUserCode()) + .attribute(OAuth2ParameterNames.SCOPE, client.getScopes()) + .build(); + + TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password"); + principal.setAuthenticated(true); + + OAuth2DeviceVerificationAuthenticationToken authRequest = new OAuth2DeviceVerificationAuthenticationToken( + principal, USER_CODE, Collections.emptyMap()); + + given(registeredClientRepository.findById(client.getId())).willReturn(client); + given(authorizationService.findByToken(USER_CODE, + OAuth2DeviceVerificationAuthenticationProvider.USER_CODE_TOKEN_TYPE)) + .willReturn(authorization); + given(authorizationConsentService.findById(client.getId(), principal.getName())).willReturn(null); + + Authentication result = authenticationProvider.authenticate(authRequest); + + assertThat(result).isInstanceOf(OAuth2DeviceVerificationAuthenticationToken.class); + assertThat(result.isAuthenticated()).isTrue(); + } + private static void mockAuthorizationServerContext() { AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().build(); TestAuthorizationServerContext authorizationServerContext = new TestAuthorizationServerContext(