diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java index a2297accd7..4014e27742 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java @@ -66,6 +66,11 @@ */ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2AuthorizationRequestResolver { + /** + * The default base {@code URI} used for authorization requests. + */ + public static final String DEFAULT_AUTHORIZATION_REQUEST_BASE_URI = "/oauth2/authorization"; + private static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId"; private static final char PATH_DELIMITER = '/'; @@ -86,6 +91,16 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au private Consumer authorizationRequestCustomizer = (customizer) -> { }; + /** + * Constructs a {@code DefaultOAuth2AuthorizationRequestResolver} using the provided + * parameters. + * @param clientRegistrationRepository the repository of client registrations + * authorization requests + */ + public DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository clientRegistrationRepository) { + this(clientRegistrationRepository, DEFAULT_AUTHORIZATION_REQUEST_BASE_URI); + } + /** * Constructs a {@code DefaultOAuth2AuthorizationRequestResolver} using the provided * parameters. diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java index 65d0be3123..bbf9be1363 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java @@ -87,7 +87,7 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt /** * The default base {@code URI} used for authorization requests. */ - public static final String DEFAULT_AUTHORIZATION_REQUEST_BASE_URI = "/oauth2/authorization"; + public static final String DEFAULT_AUTHORIZATION_REQUEST_BASE_URI = DefaultOAuth2AuthorizationRequestResolver.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI; private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer(); @@ -107,7 +107,7 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt * @param clientRegistrationRepository the repository of client registrations */ public OAuth2AuthorizationRequestRedirectFilter(ClientRegistrationRepository clientRegistrationRepository) { - this(clientRegistrationRepository, DEFAULT_AUTHORIZATION_REQUEST_BASE_URI); + this(new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository)); } /** @@ -119,10 +119,7 @@ public OAuth2AuthorizationRequestRedirectFilter(ClientRegistrationRepository cli */ public OAuth2AuthorizationRequestRedirectFilter(ClientRegistrationRepository clientRegistrationRepository, String authorizationRequestBaseUri) { - Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); - Assert.hasText(authorizationRequestBaseUri, "authorizationRequestBaseUri cannot be empty"); - this.authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository, - authorizationRequestBaseUri); + this(new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository, authorizationRequestBaseUri)); } /** diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java index a0abf7132e..1382a0368f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java @@ -97,6 +97,12 @@ public void setUp() { this.authorizationRequestBaseUri); } + @Test + void authorizationRequestBaseUriEqualToRedirectFilter() { + assertThat(DefaultOAuth2AuthorizationRequestResolver.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI) + .isEqualTo(OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI); + } + @Test public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException() @@ -568,6 +574,18 @@ public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQuery + "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" + "appid=client-id"); } + @Test + public void resolveWhenAuthorizationRequestNoProvideAuthorizationRequestBaseUri() { + OAuth2AuthorizationRequestResolver resolver = new DefaultOAuth2AuthorizationRequestResolver( + this.clientRegistrationRepository); + String requestUri = this.authorizationRequestBaseUri + "/" + this.registration2.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + OAuth2AuthorizationRequest authorizationRequest = resolver.resolve(request); + assertThat(authorizationRequest.getRedirectUri()) + .isEqualTo("http://localhost/login/oauth2/code/" + this.registration2.getRegistrationId()); + } + @Test public void resolveWhenAuthorizationRequestProvideCodeChallengeMethod() { ClientRegistration clientRegistration = this.pkceClientRegistration;