diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index bd3f441359b..d32b560b311 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -19,6 +19,7 @@ import org.reactivestreams.Subscription; import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; @@ -95,6 +96,7 @@ * * @author Rob Winch * @author Joe Grandja + * @author Roman Matiushchenko * @since 5.1 * @see OAuth2AuthorizedClientManager */ @@ -174,7 +176,7 @@ private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManage @Override public void afterPropertiesSet() throws Exception { - Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub))); + Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.liftPublisher((s, sub) -> createRequestContextSubscriberIfNecessary(sub))); } @Override @@ -378,14 +380,22 @@ private Mono mergeRequestAttributesFromContext(ClientRequest requ } private void populateRequestAttributes(Map attrs, Context ctx) { - if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) { - attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME)); - } - if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) { - attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME)); - } - if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) { - attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME)); + RequestContextDataHolder holder = RequestContextSubscriber.getRequestContext(ctx); + if (holder != null) { + HttpServletRequest request = holder.getRequest(); + if (request != null) { + attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request); + } + + HttpServletResponse response = holder.getResponse(); + if (response != null) { + attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response); + } + + Authentication authentication = holder.getAuthentication(); + if (authentication != null) { + attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication); + } } } @@ -472,7 +482,7 @@ private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient autho .build(); } - private CoreSubscriber createRequestContextSubscriber(CoreSubscriber delegate) { + CoreSubscriber createRequestContextSubscriberIfNecessary(CoreSubscriber delegate) { HttpServletRequest request = null; HttpServletResponse response = null; ServletRequestAttributes requestAttributes = @@ -482,6 +492,10 @@ private CoreSubscriber createRequestContextSubscriber(CoreSubscriber d response = requestAttributes.getResponse(); } Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + if (authentication == null && request == null && response == null) { + //do not need to create RequestContextSubscriber with empty data + return delegate; + } return new RequestContextSubscriber<>(delegate, request, response, authentication); } @@ -553,34 +567,37 @@ private UnsupportedOperationException unsupported() { } } - private static class RequestContextSubscriber implements CoreSubscriber { - private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME"); + static class RequestContextSubscriber implements CoreSubscriber { + static final String REQUEST_CONTEXT_DATA_HOLDER = + RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER"); private final CoreSubscriber delegate; - private final HttpServletRequest request; - private final HttpServletResponse response; - private final Authentication authentication; + private final Context context; - private RequestContextSubscriber(CoreSubscriber delegate, - HttpServletRequest request, - HttpServletResponse response, - Authentication authentication) { + RequestContextSubscriber(CoreSubscriber delegate, + HttpServletRequest request, + HttpServletResponse response, + Authentication authentication) { this.delegate = delegate; - this.request = request; - this.response = response; - this.authentication = authentication; + + Context parentContext = this.delegate.currentContext(); + Context context; + if (parentContext.hasKey(REQUEST_CONTEXT_DATA_HOLDER)) { + context = parentContext; + } else { + context = parentContext.put(REQUEST_CONTEXT_DATA_HOLDER, new RequestContextDataHolder(request, response, authentication)); + } + + this.context = context; + } + + @Nullable + private static RequestContextDataHolder getRequestContext(Context ctx) { + return ctx.getOrDefault(REQUEST_CONTEXT_DATA_HOLDER, null); } @Override public Context currentContext() { - Context context = this.delegate.currentContext(); - if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) { - return context; - } - return Context.of( - CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE, - HTTP_SERVLET_REQUEST_ATTR_NAME, this.request, - HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response, - AUTHENTICATION_ATTR_NAME, this.authentication); + return this.context; } @Override @@ -603,4 +620,33 @@ public void onComplete() { this.delegate.onComplete(); } } + + static class RequestContextDataHolder { + private final HttpServletRequest request; + private final HttpServletResponse response; + private final Authentication authentication; + + RequestContextDataHolder(@Nullable HttpServletRequest request, + @Nullable HttpServletResponse response, + @Nullable Authentication authentication) { + this.request = request; + this.response = response; + this.authentication = authentication; + } + + @Nullable + private HttpServletRequest getRequest() { + return this.request; + } + + @Nullable + private HttpServletResponse getResponse() { + return this.response; + } + + @Nullable + private Authentication getAuthentication() { + return this.authentication; + } + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 4c866c8abb3..f9278d5d644 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -72,6 +72,10 @@ import org.springframework.web.reactive.function.BodyInserter; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; import java.net.URI; import java.time.Duration; @@ -84,6 +88,7 @@ import java.util.function.Consumer; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.mockito.Mockito.*; import static org.springframework.http.HttpMethod.GET; @@ -144,9 +149,10 @@ public void setup() { } @After - public void cleanup() { + public void cleanup() throws Exception { SecurityContextHolder.clearContext(); RequestContextHolder.resetRequestAttributes(); + this.function.destroy(); } @Test @@ -633,6 +639,90 @@ public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsN assertThat(getBody(request)).isEmpty(); } + // gh-7228 + @Test + public void afterPropertiesSetWhenHooksInitAndOutsideWebSecurityContextThenShouldNotThrowException() throws Exception { + this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized + assertThatCode(() -> Mono.subscriberContext().block()) + .as("RequestContext Hook brakes application outside of web/security context") + .doesNotThrowAnyException(); + } + + @Test + public void createRequestContextSubscriberIfNecessaryWhenOutsideWebSecurityContextThenReturnOriginalSubscriber() throws Exception { + BaseSubscriber originalSubscriber = new BaseSubscriber() {}; + CoreSubscriber resultSubscriber = this.function.createRequestContextSubscriberIfNecessary(originalSubscriber); + assertThat(resultSubscriber).isSameAs(originalSubscriber); + } + + // gh-7228 + @Test + public void createRequestContextSubscriberWhenRequestResponseProvidedThenCreateWithParentContext() throws Exception { + testRequestContextSubscriber(new MockHttpServletRequest(), new MockHttpServletResponse(), null); + } + + // gh-7228 + @Test + public void createRequestContextSubscriberWhenAuthenticationProvidedThenCreateWithParentContext() throws Exception { + testRequestContextSubscriber(null, null, this.authentication); + } + + @Test + public void createRequestContextSubscriberWhenParentContextHasDataHolderThenShouldReuseParentContext() throws Exception { + RequestContextDataHolder testValue = new RequestContextDataHolder(null, null, null); + final Context parentContext = Context.of(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, testValue); + BaseSubscriber parent = new BaseSubscriber() { + @Override + public Context currentContext() { + return parentContext; + } + }; + + RequestContextSubscriber requestContextSubscriber = + new RequestContextSubscriber<>(parent, null, null, authentication); + + Context resultContext = requestContextSubscriber.currentContext(); + + assertThat(resultContext) + .describedAs("parent context was replaced") + .isSameAs(parentContext); + } + + private void testRequestContextSubscriber(MockHttpServletRequest servletRequest, + MockHttpServletResponse servletResponse, + Authentication authentication) { + String testKey = "test_key"; + String testValue = "test_value"; + + BaseSubscriber parent = new BaseSubscriber() { + @Override + public Context currentContext() { + return Context.of(testKey, testValue); + } + }; + + RequestContextSubscriber requestContextSubscriber = + new RequestContextSubscriber<>(parent, servletRequest, servletResponse, authentication); + + Context resultContext = requestContextSubscriber.currentContext(); + + assertThat(resultContext) + .describedAs("result context is null") + .isNotNull(); + + assertThat(resultContext.getOrEmpty(testKey)) + .describedAs("context is replaced") + .hasValue(testValue); + + Object dataHolder = resultContext.getOrDefault(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, null); + assertThat(dataHolder) + .describedAs("context is not populated with REQUEST_CONTEXT_DATA_HOLDER") + .isNotNull() + .hasFieldOrPropertyWithValue("request", servletRequest) + .hasFieldOrPropertyWithValue("response", servletResponse) + .hasFieldOrPropertyWithValue("authentication", authentication); + } + private static String getBody(ClientRequest request) { final List> messageWriters = new ArrayList<>(); messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));