From 2d3d7a0ca5f2342a1392b2d4e7f2a9c5d70e5325 Mon Sep 17 00:00:00 2001 From: "greg.lee" Date: Sun, 19 Nov 2023 02:28:27 +0900 Subject: [PATCH 1/2] Introduce Customizable AuthorizationFailureHandler in OAuth2AuthorizationRequestRedirectFilter --- ...th2AuthorizationRequestRedirectFilter.java | 41 +++++++++++-------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java index 2923f05bf0..5c257c6de2 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java @@ -97,6 +97,20 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt private RequestCache requestCache = new HttpSessionRequestCache(); + private AuthorizationFailureHandler failureHandler = (request, response, ex) -> { + LogMessage message = LogMessage.format("Authorization Request failed: %s", ex); + if (InvalidClientRegistrationIdException.class.isAssignableFrom(ex.getClass())) { + // Log an invalid registrationId at WARN level to allow these errors to be + // tuned separately from other errors + this.logger.warn(message, ex); + } + else { + this.logger.error(message, ex); + } + response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(), + HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase()); + }; + /** * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided * parameters. @@ -163,6 +177,10 @@ public final void setRequestCache(RequestCache requestCache) { this.requestCache = requestCache; } + public final void setFailureHandler(AuthorizationFailureHandler failureHandler) { + this.failureHandler = failureHandler; + } + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { @@ -174,7 +192,7 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse } } catch (Exception ex) { - this.unsuccessfulRedirectForAuthorization(request, response, ex); + this.failureHandler.onAuthorizationFailure(request, response, ex); return; } try { @@ -199,7 +217,7 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse this.sendRedirectForAuthorization(request, response, authorizationRequest); } catch (Exception failed) { - this.unsuccessfulRedirectForAuthorization(request, response, failed); + this.failureHandler.onAuthorizationFailure(request, response, failed); } return; } @@ -222,21 +240,6 @@ private void sendRedirectForAuthorization(HttpServletRequest request, HttpServle authorizationRequest.getAuthorizationRequestUri()); } - private void unsuccessfulRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response, - Exception ex) throws IOException { - LogMessage message = LogMessage.format("Authorization Request failed: %s", ex); - if (InvalidClientRegistrationIdException.class.isAssignableFrom(ex.getClass())) { - // Log an invalid registrationId at WARN level to allow these errors to be - // tuned separately from other errors - this.logger.warn(message, ex); - } - else { - this.logger.error(message, ex); - } - response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(), - HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase()); - } - private static final class DefaultThrowableAnalyzer extends ThrowableAnalyzer { @Override @@ -250,4 +253,8 @@ protected void initExtractorMap() { } + public interface AuthorizationFailureHandler { + void onAuthorizationFailure(HttpServletRequest request, HttpServletResponse response, + Exception ex) throws IOException; + } } From e3de24ad83088d87c576706d231979453ba1cdb3 Mon Sep 17 00:00:00 2001 From: "greg.lee" Date: Sat, 10 Feb 2024 22:49:45 +0900 Subject: [PATCH 2/2] Introduce authenticationFailureHandler for OAuth2AuthorizationRequestRedirectFilter | gh-13793 --- ...th2AuthorizationRequestRedirectFilter.java | 62 ++++++++++++------- ...thorizationRequestRedirectFilterTests.java | 32 +++++++++- 2 files changed, 72 insertions(+), 22 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java index 5c257c6de2..65d0be3123 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import org.springframework.core.log.LogMessage; import org.springframework.http.HttpStatus; +import org.springframework.security.core.AuthenticationException; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; @@ -32,6 +33,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.RedirectStrategy; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.savedrequest.HttpSessionRequestCache; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.util.ThrowableAnalyzer; @@ -97,19 +99,7 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt private RequestCache requestCache = new HttpSessionRequestCache(); - private AuthorizationFailureHandler failureHandler = (request, response, ex) -> { - LogMessage message = LogMessage.format("Authorization Request failed: %s", ex); - if (InvalidClientRegistrationIdException.class.isAssignableFrom(ex.getClass())) { - // Log an invalid registrationId at WARN level to allow these errors to be - // tuned separately from other errors - this.logger.warn(message, ex); - } - else { - this.logger.error(message, ex); - } - response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(), - HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase()); - }; + private AuthenticationFailureHandler authenticationFailureHandler = this::unsuccessfulRedirectForAuthorization; /** * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided @@ -177,8 +167,16 @@ public final void setRequestCache(RequestCache requestCache) { this.requestCache = requestCache; } - public final void setFailureHandler(AuthorizationFailureHandler failureHandler) { - this.failureHandler = failureHandler; + /** + * Sets the {@link AuthenticationFailureHandler} used to handle errors redirecting to + * the Authorization Server's Authorization Endpoint. + * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used + * to handle errors redirecting to the Authorization Server's Authorization Endpoint + * @since 6.3 + */ + public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) { + Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); + this.authenticationFailureHandler = authenticationFailureHandler; } @Override @@ -192,7 +190,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse } } catch (Exception ex) { - this.failureHandler.onAuthorizationFailure(request, response, ex); + AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex); + this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException); return; } try { @@ -217,7 +216,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse this.sendRedirectForAuthorization(request, response, authorizationRequest); } catch (Exception failed) { - this.failureHandler.onAuthorizationFailure(request, response, failed); + AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex); + this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException); } return; } @@ -240,6 +240,22 @@ private void sendRedirectForAuthorization(HttpServletRequest request, HttpServle authorizationRequest.getAuthorizationRequestUri()); } + private void unsuccessfulRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response, + AuthenticationException ex) throws IOException { + Throwable cause = ex.getCause(); + LogMessage message = LogMessage.format("Authorization Request failed: %s", cause); + if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) { + // Log an invalid registrationId at WARN level to allow these errors to be + // tuned separately from other errors + this.logger.warn(message, ex); + } + else { + this.logger.error(message, ex); + } + response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(), + HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase()); + } + private static final class DefaultThrowableAnalyzer extends ThrowableAnalyzer { @Override @@ -253,8 +269,12 @@ protected void initExtractorMap() { } - public interface AuthorizationFailureHandler { - void onAuthorizationFailure(HttpServletRequest request, HttpServletResponse response, - Exception ex) throws IOException; + private static final class OAuth2AuthorizationRequestException extends AuthenticationException { + + OAuth2AuthorizationRequestException(Throwable cause) { + super(cause.getMessage(), cause); + } + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java index 7123f0169e..59676b1461 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -119,6 +119,11 @@ public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentExcepti assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null)); } + @Test + public void setAuthenticationFailureHandlerIsNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null)); + } + @Test public void doFilterWhenNotAuthorizationRequestThenNextFilter() throws Exception { String requestUri = "/path"; @@ -144,6 +149,31 @@ public void doFilterWhenAuthorizationRequestWithInvalidClientThenStatusInternalS assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase()); } + @Test + public void doFilterWhenAuthorizationRequestWithInvalidClientAndCustomFailureHandlerThenCustomError() + throws Exception { + String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + + this.registration1.getRegistrationId() + "-invalid"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.setAuthenticationFailureHandler((request1, response1, ex) -> { + Throwable cause = ex.getCause(); + if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) { + response1.sendError(HttpStatus.BAD_REQUEST.value(), HttpStatus.BAD_REQUEST.getReasonPhrase()); + } + else { + response1.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(), + HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase()); + } + }); + this.filter.doFilter(request, response, filterChain); + verifyNoMoreInteractions(filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.BAD_REQUEST.getReasonPhrase()); + } + @Test public void doFilterWhenAuthorizationRequestOAuth2LoginThenRedirectForAuthorization() throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"