From 58a2284177e2fa4e769f6f35394ed02d91389830 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 24 Sep 2019 20:35:03 -0400 Subject: [PATCH] Align Servlet ExchangeFilterFunction CoreSubscriber Fixes gh-7422 --- .../configuration/OAuth2ImportSelector.java | 34 +++- .../SecurityReactorContextConfiguration.java | 143 +++++++++++++ ...urityReactorContextConfigurationTests.java | 189 ++++++++++++++++++ ...uthorizedClientExchangeFilterFunction.java | 157 ++------------- ...zedClientExchangeFilterFunctionITests.java | 15 +- ...izedClientExchangeFilterFunctionTests.java | 161 +-------------- 6 files changed, 398 insertions(+), 301 deletions(-) create mode 100644 config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java create mode 100644 config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ImportSelector.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ImportSelector.java index 1012b119c6b..3b5b925fe41 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ImportSelector.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ImportSelector.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,23 +15,28 @@ */ package org.springframework.security.config.annotation.web.configuration; -import java.util.ArrayList; -import java.util.List; - import org.springframework.context.annotation.ImportSelector; import org.springframework.core.type.AnnotationMetadata; import org.springframework.util.ClassUtils; +import java.util.ArrayList; +import java.util.List; + /** - * Used by {@link EnableWebSecurity} to conditionally import {@link OAuth2ClientConfiguration} - * when the {@code spring-security-oauth2-client} module is present on the classpath and - * {@link OAuth2ResourceServerConfiguration} when the {@code spring-security-oauth2-resource-server} - * module is on the classpath. + * Used by {@link EnableWebSecurity} to conditionally import: + * + * * * @author Joe Grandja * @author Josh Cummings * @since 5.1 * @see OAuth2ClientConfiguration + * @see SecurityReactorContextConfiguration + * @see OAuth2ResourceServerConfiguration */ final class OAuth2ImportSelector implements ImportSelector { @@ -39,13 +44,20 @@ final class OAuth2ImportSelector implements ImportSelector { public String[] selectImports(AnnotationMetadata importingClassMetadata) { List imports = new ArrayList<>(); - if (ClassUtils.isPresent( - "org.springframework.security.oauth2.client.registration.ClientRegistration", getClass().getClassLoader())) { + boolean oauth2ClientPresent = ClassUtils.isPresent( + "org.springframework.security.oauth2.client.registration.ClientRegistration", getClass().getClassLoader()); + if (oauth2ClientPresent) { imports.add("org.springframework.security.config.annotation.web.configuration.OAuth2ClientConfiguration"); } + boolean webfluxPresent = ClassUtils.isPresent( + "org.springframework.web.reactive.function.client.ExchangeFilterFunction", getClass().getClassLoader()); + if (webfluxPresent && oauth2ClientPresent) { + imports.add("org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration"); + } + if (ClassUtils.isPresent( - "org.springframework.security.oauth2.server.resource.BearerTokenError", getClass().getClassLoader())) { + "org.springframework.security.oauth2.server.resource.BearerTokenError", getClass().getClassLoader())) { imports.add("org.springframework.security.config.annotation.web.configuration.OAuth2ResourceServerConfiguration"); } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java new file mode 100644 index 00000000000..ef524167b51 --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java @@ -0,0 +1,143 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.annotation.web.configuration; + +import org.reactivestreams.Subscription; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.util.HashMap; +import java.util.Map; + +/** + * {@link Configuration} that adds a {@code Publisher} for the last operator created + * in every {@code Mono} or {@code Flux}. + * + *

+ * The {@code Publisher} is solely responsible for adding + * the current {@code HttpServletRequest}, {@code HttpServletResponse} and {@code Authentication} + * to the Reactor {@code Context} so that it's accessible in every flow, if required. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2ImportSelector + */ +@Configuration(proxyBeanMethods = false) +class SecurityReactorContextConfiguration { + + @Bean + SecurityReactorContextSubscriberRegistrar securityReactorContextSubscriberRegistrar() { + return new SecurityReactorContextSubscriberRegistrar(); + } + + static class SecurityReactorContextSubscriberRegistrar implements InitializingBean, DisposableBean { + private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR"; + + @Override + public void afterPropertiesSet() throws Exception { + Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, + Operators.liftPublisher((s, sub) -> createSubscriberIfNecessary(sub))); + } + + @Override + public void destroy() throws Exception { + Hooks.resetOnLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY); + } + + CoreSubscriber createSubscriberIfNecessary(CoreSubscriber delegate) { + HttpServletRequest servletRequest = null; + HttpServletResponse servletResponse = null; + ServletRequestAttributes requestAttributes = + (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); + if (requestAttributes != null) { + servletRequest = requestAttributes.getRequest(); + servletResponse = requestAttributes.getResponse(); + } + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + if (authentication == null && servletRequest == null && servletResponse == null) { + // No need to create Subscriber so return original + return delegate; + } + + Map attributes = new HashMap<>(); + if (servletRequest != null) { + attributes.put(HttpServletRequest.class, servletRequest); + } + if (servletResponse != null) { + attributes.put(HttpServletResponse.class, servletResponse); + } + if (authentication != null) { + attributes.put(Authentication.class, authentication); + } + return new SecurityReactorContextSubscriber<>(delegate, attributes); + } + } + + static class SecurityReactorContextSubscriber implements CoreSubscriber { + static final String SECURITY_CONTEXT_ATTRIBUTES = "org.springframework.security.SECURITY_CONTEXT_ATTRIBUTES"; + private final CoreSubscriber delegate; + private final Context context; + + SecurityReactorContextSubscriber(CoreSubscriber delegate, Map attributes) { + this.delegate = delegate; + Context currentContext = this.delegate.currentContext(); + Context context; + if (currentContext.hasKey(SECURITY_CONTEXT_ATTRIBUTES)) { + context = currentContext; + } else { + context = currentContext.put(SECURITY_CONTEXT_ATTRIBUTES, attributes); + } + this.context = context; + } + + @Override + public Context currentContext() { + return this.context; + } + + @Override + public void onSubscribe(Subscription s) { + this.delegate.onSubscribe(s); + } + + @Override + public void onNext(T t) { + this.delegate.onNext(t); + } + + @Override + public void onError(Throwable t) { + this.delegate.onError(t); + } + + @Override + public void onComplete() { + this.delegate.onComplete(); + } + } +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java new file mode 100644 index 00000000000..f3d81f67e82 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java @@ -0,0 +1,189 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.annotation.web.configuration; + +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.springframework.http.HttpStatus; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.util.context.Context; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.net.URI; +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.entry; +import static org.springframework.http.HttpMethod.GET; +import static org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES; + +/** + * Tests for {@link SecurityReactorContextConfiguration}. + * + * @author Joe Grandja + * @since 5.2 + */ +public class SecurityReactorContextConfigurationTests { + private MockHttpServletRequest servletRequest; + private MockHttpServletResponse servletResponse; + private Authentication authentication; + private SecurityReactorContextConfiguration.SecurityReactorContextSubscriberRegistrar subscriberRegistrar = + new SecurityReactorContextConfiguration.SecurityReactorContextSubscriberRegistrar(); + + @Rule + public final SpringTestRule spring = new SpringTestRule(); + + @Before + public void setup() { + this.servletRequest = new MockHttpServletRequest(); + this.servletResponse = new MockHttpServletResponse(); + this.authentication = new TestingAuthenticationToken("principal", "password"); + } + + @After + public void cleanup() { + SecurityContextHolder.clearContext(); + RequestContextHolder.resetRequestAttributes(); + } + + @Test + public void createSubscriberIfNecessaryWhenWebSecurityContextUnavailableThenReturnOriginalSubscriber() { + BaseSubscriber originalSubscriber = new BaseSubscriber() {}; + CoreSubscriber resultSubscriber = this.subscriberRegistrar.createSubscriberIfNecessary(originalSubscriber); + assertThat(resultSubscriber).isSameAs(originalSubscriber); + } + + @Test + public void createSubscriberIfNecessaryWhenWebSecurityContextAvailableThenCreateWithParentContext() { + RequestContextHolder.setRequestAttributes( + new ServletRequestAttributes(this.servletRequest, this.servletResponse)); + SecurityContextHolder.getContext().setAuthentication(this.authentication); + + String testKey = "test_key"; + String testValue = "test_value"; + + BaseSubscriber parent = new BaseSubscriber() { + @Override + public Context currentContext() { + return Context.of(testKey, testValue); + } + }; + CoreSubscriber subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(parent); + + Context resultContext = subscriber.currentContext(); + + assertThat(resultContext.getOrEmpty(testKey)).hasValue(testValue); + Map securityContextAttributes = resultContext.getOrDefault(SECURITY_CONTEXT_ATTRIBUTES, null); + assertThat(securityContextAttributes).hasSize(3); + assertThat(securityContextAttributes).contains( + entry(HttpServletRequest.class, this.servletRequest), + entry(HttpServletResponse.class, this.servletResponse), + entry(Authentication.class, this.authentication)); + } + + @Test + public void createSubscriberIfNecessaryWhenParentContextContainsSecurityContextAttributesThenUseParentContext() { + RequestContextHolder.setRequestAttributes( + new ServletRequestAttributes(this.servletRequest, this.servletResponse)); + SecurityContextHolder.getContext().setAuthentication(this.authentication); + + Context parentContext = Context.of(SECURITY_CONTEXT_ATTRIBUTES, new HashMap<>()); + BaseSubscriber parent = new BaseSubscriber() { + @Override + public Context currentContext() { + return parentContext; + } + }; + CoreSubscriber subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(parent); + + Context resultContext = subscriber.currentContext(); + assertThat(resultContext).isSameAs(parentContext); + } + + @Test + public void createPublisherWhenLastOperatorAddedThenSecurityContextAttributesAvailable() { + // Trigger the importing of SecurityReactorContextConfiguration via OAuth2ImportSelector + this.spring.register(SecurityConfig.class).autowire(); + + // Setup for SecurityReactorContextSubscriberRegistrar + RequestContextHolder.setRequestAttributes( + new ServletRequestAttributes(this.servletRequest, this.servletResponse)); + SecurityContextHolder.getContext().setAuthentication(this.authentication); + + ClientResponse clientResponseOk = ClientResponse.create(HttpStatus.OK).build(); + + ExchangeFilterFunction filter = (req, next) -> + Mono.subscriberContext() + .filter(ctx -> ctx.hasKey(SECURITY_CONTEXT_ATTRIBUTES)) + .map(ctx -> ctx.get(SECURITY_CONTEXT_ATTRIBUTES)) + .cast(Map.class) + .map(attributes -> { + if (attributes.containsKey(HttpServletRequest.class) && + attributes.containsKey(HttpServletResponse.class) && + attributes.containsKey(Authentication.class)) { + return clientResponseOk; + } else { + return ClientResponse.create(HttpStatus.NOT_FOUND).build(); + } + }); + + ClientRequest clientRequest = ClientRequest.create(GET, URI.create("https://example.com")).build(); + MockExchangeFunction exchange = new MockExchangeFunction(); + + Map expectedContextAttributes = new HashMap<>(); + expectedContextAttributes.put(HttpServletRequest.class, this.servletRequest); + expectedContextAttributes.put(HttpServletResponse.class, this.servletResponse); + expectedContextAttributes.put(Authentication.class, this.authentication); + + Mono clientResponseMono = filter.filter(clientRequest, exchange) + .flatMap(response -> filter.filter(clientRequest, exchange)); + + StepVerifier.create(clientResponseMono) + .expectAccessibleContext() + .contains(SECURITY_CONTEXT_ATTRIBUTES, expectedContextAttributes) + .then() + .expectNext(clientResponseOk) + .verifyComplete(); + } + + @EnableWebSecurity + static class SecurityConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + } + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 72b372300f3..147e3255c0f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -16,10 +16,6 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; -import org.reactivestreams.Subscription; -import org.springframework.beans.factory.DisposableBean; -import org.springframework.beans.factory.InitializingBean; -import org.springframework.lang.Nullable; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; @@ -47,10 +43,7 @@ import org.springframework.web.reactive.function.client.ExchangeFilterFunction; import org.springframework.web.reactive.function.client.ExchangeFunction; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; -import reactor.core.publisher.Operators; import reactor.core.scheduler.Schedulers; import reactor.util.context.Context; @@ -100,8 +93,10 @@ * @since 5.1 * @see OAuth2AuthorizedClientManager */ -public final class ServletOAuth2AuthorizedClientExchangeFilterFunction - implements ExchangeFilterFunction, InitializingBean, DisposableBean { +public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction { + + // Same key as in SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES + static final String SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY = "org.springframework.security.SECURITY_CONTEXT_ATTRIBUTES"; /** * The request attribute name used to locate the {@link OAuth2AuthorizedClient}. @@ -112,8 +107,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName(); private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName(); - private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName(); - private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken( "anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); @@ -175,16 +168,6 @@ private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManage return authorizedClientManager; } - @Override - public void afterPropertiesSet() { - Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.liftPublisher((s, sub) -> createRequestContextSubscriberIfNecessary(sub))); - } - - @Override - public void destroy() { - Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY); - } - /** * Sets the {@link OAuth2AccessTokenResponseClient} used for getting an {@link OAuth2AuthorizedClient} for the client_credentials grant. * @@ -382,22 +365,22 @@ private Mono mergeRequestAttributesFromContext(ClientRequest requ } private void populateRequestAttributes(Map attrs, Context ctx) { - RequestContextDataHolder holder = RequestContextSubscriber.getRequestContext(ctx); - if (holder != null) { - HttpServletRequest request = holder.getRequest(); - if (request != null) { - attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request); - } - - HttpServletResponse response = holder.getResponse(); - if (response != null) { - attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response); - } - - Authentication authentication = holder.getAuthentication(); - if (authentication != null) { - attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication); - } + // NOTE: SecurityReactorContextConfiguration.SecurityReactorContextSubscriber adds this key + if (!ctx.hasKey(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY)) { + return; + } + Map contextAttributes = ctx.get(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY); + HttpServletRequest servletRequest = (HttpServletRequest) contextAttributes.get(HttpServletRequest.class); + if (servletRequest != null) { + attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, servletRequest); + } + HttpServletResponse servletResponse = (HttpServletResponse) contextAttributes.get(HttpServletResponse.class); + if (servletResponse != null) { + attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, servletResponse); + } + Authentication authentication = (Authentication) contextAttributes.get(Authentication.class); + if (authentication != null) { + attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication); } } @@ -503,23 +486,6 @@ private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient autho .build(); } - CoreSubscriber createRequestContextSubscriberIfNecessary(CoreSubscriber delegate) { - HttpServletRequest request = null; - HttpServletResponse response = null; - ServletRequestAttributes requestAttributes = - (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); - if (requestAttributes != null) { - request = requestAttributes.getRequest(); - response = requestAttributes.getResponse(); - } - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - if (authentication == null && request == null && response == null) { - //do not need to create RequestContextSubscriber with empty data - return delegate; - } - return new RequestContextSubscriber<>(delegate, request, response, authentication); - } - static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map attrs) { return (OAuth2AuthorizedClient) attrs.get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME); } @@ -587,87 +553,4 @@ private UnsupportedOperationException unsupported() { return new UnsupportedOperationException("Not Supported"); } } - - static class RequestContextSubscriber implements CoreSubscriber { - static final String REQUEST_CONTEXT_DATA_HOLDER = - RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER"); - private final CoreSubscriber delegate; - private final Context context; - - RequestContextSubscriber(CoreSubscriber delegate, - HttpServletRequest request, - HttpServletResponse response, - Authentication authentication) { - this.delegate = delegate; - - Context parentContext = this.delegate.currentContext(); - Context context; - if (parentContext.hasKey(REQUEST_CONTEXT_DATA_HOLDER)) { - context = parentContext; - } else { - context = parentContext.put(REQUEST_CONTEXT_DATA_HOLDER, new RequestContextDataHolder(request, response, authentication)); - } - - this.context = context; - } - - @Nullable - private static RequestContextDataHolder getRequestContext(Context ctx) { - return ctx.getOrDefault(REQUEST_CONTEXT_DATA_HOLDER, null); - } - - @Override - public Context currentContext() { - return this.context; - } - - @Override - public void onSubscribe(Subscription s) { - this.delegate.onSubscribe(s); - } - - @Override - public void onNext(T t) { - this.delegate.onNext(t); - } - - @Override - public void onError(Throwable t) { - this.delegate.onError(t); - } - - @Override - public void onComplete() { - this.delegate.onComplete(); - } - } - - static class RequestContextDataHolder { - private final HttpServletRequest request; - private final HttpServletResponse response; - private final Authentication authentication; - - RequestContextDataHolder(@Nullable HttpServletRequest request, - @Nullable HttpServletResponse response, - @Nullable Authentication authentication) { - this.request = request; - this.response = response; - this.authentication = authentication; - } - - @Nullable - private HttpServletRequest getRequest() { - return this.request; - } - - @Nullable - private HttpServletResponse getResponse() { - return this.response; - } - - @Nullable - private Authentication getAuthentication() { - return this.authentication; - } - } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java index 1ba6559cd9f..4fc245d31b6 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java @@ -43,16 +43,20 @@ import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.reactive.function.client.WebClient; import reactor.blockhound.BlockHound; +import reactor.util.context.Context; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.time.Duration; import java.time.Instant; import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.*; +import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY; import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; /** @@ -104,7 +108,6 @@ public void removeAuthorizedClient(String clientRegistrationId, Authentication p }); this.authorizedClientFilter = new ServletOAuth2AuthorizedClientExchangeFilterFunction( this.clientRegistrationRepository, this.authorizedClientRepository); - this.authorizedClientFilter.afterPropertiesSet(); this.server = new MockWebServer(); this.server.start(); this.serverUrl = this.server.url("/").toString(); @@ -120,7 +123,6 @@ public void removeAuthorizedClient(String clientRegistrationId, Authentication p @After public void cleanup() throws Exception { - this.authorizedClientFilter.destroy(); this.server.shutdown(); SecurityContextHolder.clearContext(); RequestContextHolder.resetRequestAttributes(); @@ -248,6 +250,7 @@ public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() { .attributes(clientRegistrationId(clientRegistration2.getRegistrationId())) .retrieve() .bodyToMono(String.class)) + .subscriberContext(context()) .block(); assertThat(this.server.getRequestCount()).isEqualTo(4); @@ -259,6 +262,14 @@ public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() { assertThat(authorizedClientCaptor.getAllValues().get(1).getClientRegistration()).isSameAs(clientRegistration2); } + private Context context() { + Map contextAttributes = new HashMap<>(); + contextAttributes.put(HttpServletRequest.class, this.request); + contextAttributes.put(HttpServletResponse.class, this.response); + contextAttributes.put(Authentication.class, this.authentication); + return Context.of(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY, contextAttributes); + } + private MockResponse jsonResponse(String json) { return new MockResponse() .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 7eb7cfa8aa4..cb1d831d9ea 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -76,12 +76,10 @@ import org.springframework.web.reactive.function.BodyInserter; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.BaseSubscriber; -import reactor.core.publisher.Mono; import reactor.util.context.Context; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import java.net.URI; import java.time.Duration; import java.time.Instant; @@ -93,7 +91,6 @@ import java.util.function.Consumer; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.mockito.Mockito.*; import static org.springframework.http.HttpMethod.GET; @@ -163,7 +160,6 @@ public void setup() { public void cleanup() throws Exception { SecurityContextHolder.clearContext(); RequestContextHolder.resetRequestAttributes(); - this.function.destroy(); } @Test @@ -591,18 +587,15 @@ public void filterWhenNotExpiredThenShouldRefreshFalse() { // gh-6483 @Test public void filterWhenChainedThenDefaultsStillAvailable() throws Exception { - this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized this.function.setDefaultOAuth2AuthorizedClient(true); MockHttpServletRequest servletRequest = new MockHttpServletRequest(); MockHttpServletResponse servletResponse = new MockHttpServletResponse(); - RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse)); OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken( user, authorities, this.registration.getRegistrationId()); - SecurityContextHolder.getContext().setAuthentication(authentication); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( this.registration, "principalName", this.accessToken); @@ -619,12 +612,13 @@ public void filterWhenChainedThenDefaultsStillAvailable() throws Exception { // Default request attributes NOT set final ClientRequest request2 = ClientRequest.create(GET, URI.create("https://example2.com")).build(); + Context context = context(servletRequest, servletResponse, authentication); + this.function.filter(request1, this.exchange) .flatMap(response -> this.function.filter(request2, this.exchange)) + .subscriberContext(context) .block(); - this.function.destroy(); // Hooks.onLastOperator() released - List requests = this.exchange.getRequests(); assertThat(requests).hasSize(2); @@ -641,147 +635,12 @@ public void filterWhenChainedThenDefaultsStillAvailable() throws Exception { assertThat(getBody(request)).isEmpty(); } - @Test - public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvailable() { -// this.function.afterPropertiesSet(); // Hooks.onLastOperator() NOT initialized - this.function.setDefaultOAuth2AuthorizedClient(true); - - MockHttpServletRequest servletRequest = new MockHttpServletRequest(); - MockHttpServletResponse servletResponse = new MockHttpServletResponse(); - RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse)); - - OAuth2User user = mock(OAuth2User.class); - List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); - OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken( - user, authorities, this.registration.getRegistrationId()); - SecurityContextHolder.getContext().setAuthentication(authentication); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build(); - - this.function.filter(request, this.exchange).block(); - - List requests = this.exchange.getRequests(); - assertThat(requests).hasSize(1); - - request = requests.get(0); - assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); - assertThat(request.url().toASCIIString()).isEqualTo("https://example.com"); - assertThat(request.method()).isEqualTo(HttpMethod.GET); - assertThat(getBody(request)).isEmpty(); - } - - @Test - public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsNotAvailable() throws Exception { - this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized - this.function.destroy(); // Hooks.onLastOperator() released - this.function.setDefaultOAuth2AuthorizedClient(true); - - MockHttpServletRequest servletRequest = new MockHttpServletRequest(); - MockHttpServletResponse servletResponse = new MockHttpServletResponse(); - RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse)); - - OAuth2User user = mock(OAuth2User.class); - List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); - OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken( - user, authorities, this.registration.getRegistrationId()); - SecurityContextHolder.getContext().setAuthentication(authentication); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build(); - - this.function.filter(request, this.exchange).block(); - - List requests = this.exchange.getRequests(); - assertThat(requests).hasSize(1); - - request = requests.get(0); - assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); - assertThat(request.url().toASCIIString()).isEqualTo("https://example.com"); - assertThat(request.method()).isEqualTo(HttpMethod.GET); - assertThat(getBody(request)).isEmpty(); - } - - // gh-7228 - @Test - public void afterPropertiesSetWhenHooksInitAndOutsideWebSecurityContextThenShouldNotThrowException() throws Exception { - this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized - assertThatCode(() -> Mono.subscriberContext().block()) - .as("RequestContext Hook brakes application outside of web/security context") - .doesNotThrowAnyException(); - } - - @Test - public void createRequestContextSubscriberIfNecessaryWhenOutsideWebSecurityContextThenReturnOriginalSubscriber() throws Exception { - BaseSubscriber originalSubscriber = new BaseSubscriber() {}; - CoreSubscriber resultSubscriber = this.function.createRequestContextSubscriberIfNecessary(originalSubscriber); - assertThat(resultSubscriber).isSameAs(originalSubscriber); - } - - // gh-7228 - @Test - public void createRequestContextSubscriberWhenRequestResponseProvidedThenCreateWithParentContext() throws Exception { - testRequestContextSubscriber(new MockHttpServletRequest(), new MockHttpServletResponse(), null); - } - - // gh-7228 - @Test - public void createRequestContextSubscriberWhenAuthenticationProvidedThenCreateWithParentContext() throws Exception { - testRequestContextSubscriber(null, null, this.authentication); - } - - @Test - public void createRequestContextSubscriberWhenParentContextHasDataHolderThenShouldReuseParentContext() throws Exception { - RequestContextDataHolder testValue = new RequestContextDataHolder(null, null, null); - final Context parentContext = Context.of(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, testValue); - BaseSubscriber parent = new BaseSubscriber() { - @Override - public Context currentContext() { - return parentContext; - } - }; - - RequestContextSubscriber requestContextSubscriber = - new RequestContextSubscriber<>(parent, null, null, authentication); - - Context resultContext = requestContextSubscriber.currentContext(); - - assertThat(resultContext) - .describedAs("parent context was replaced") - .isSameAs(parentContext); - } - - private void testRequestContextSubscriber(MockHttpServletRequest servletRequest, - MockHttpServletResponse servletResponse, - Authentication authentication) { - String testKey = "test_key"; - String testValue = "test_value"; - - BaseSubscriber parent = new BaseSubscriber() { - @Override - public Context currentContext() { - return Context.of(testKey, testValue); - } - }; - - RequestContextSubscriber requestContextSubscriber = - new RequestContextSubscriber<>(parent, servletRequest, servletResponse, authentication); - - Context resultContext = requestContextSubscriber.currentContext(); - - assertThat(resultContext) - .describedAs("result context is null") - .isNotNull(); - - assertThat(resultContext.getOrEmpty(testKey)) - .describedAs("context is replaced") - .hasValue(testValue); - - Object dataHolder = resultContext.getOrDefault(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, null); - assertThat(dataHolder) - .describedAs("context is not populated with REQUEST_CONTEXT_DATA_HOLDER") - .isNotNull() - .hasFieldOrPropertyWithValue("request", servletRequest) - .hasFieldOrPropertyWithValue("response", servletResponse) - .hasFieldOrPropertyWithValue("authentication", authentication); + private Context context(HttpServletRequest servletRequest, HttpServletResponse servletResponse, Authentication authentication) { + Map contextAttributes = new HashMap<>(); + contextAttributes.put(HttpServletRequest.class, servletRequest); + contextAttributes.put(HttpServletResponse.class, servletResponse); + contextAttributes.put(Authentication.class, authentication); + return Context.of(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY, contextAttributes); } private static String getBody(ClientRequest request) {