diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/ServletBearerExchangeFilterFunction.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/ServletBearerExchangeFilterFunction.java deleted file mode 100644 index 820c05ac48e..00000000000 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/ServletBearerExchangeFilterFunction.java +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Copyright 2002-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.server.resource.web; - -import java.util.Map; -import java.util.function.Consumer; - -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.Hooks; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Operators; -import reactor.util.context.Context; - -import org.springframework.beans.factory.DisposableBean; -import org.springframework.beans.factory.InitializingBean; -import org.springframework.lang.Nullable; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.oauth2.core.AbstractOAuth2Token; -import org.springframework.web.reactive.function.client.ClientRequest; -import org.springframework.web.reactive.function.client.ClientResponse; -import org.springframework.web.reactive.function.client.ExchangeFilterFunction; -import org.springframework.web.reactive.function.client.ExchangeFunction; -import org.springframework.web.reactive.function.client.WebClient; - -/** - * An {@link ExchangeFilterFunction} that adds the - * Bearer Token - * from an existing {@link AbstractOAuth2Token} tied to the current {@link Authentication}. - * - * Suitable for Servlet applications, applying it to a typical {@link org.springframework.web.reactive.function.client.WebClient} - * configuration: - * - *
- *  @Bean
- *  WebClient webClient() {
- *      ServletBearerExchangeFilterFunction bearer = new ServletBearerExchangeFilterFunction();
- *      return WebClient.builder()
- *              .apply(bearer.oauth2Configuration())
- *              .build();
- *  }
- * 
- * - * @author Josh Cummings - * @since 5.2 - */ -public class ServletBearerExchangeFilterFunction - implements ExchangeFilterFunction, InitializingBean, DisposableBean { - - private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName(); - - private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName(); - - /** - * {@inheritDoc} - */ - @Override - public void afterPropertiesSet() throws Exception { - Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, - Operators.liftPublisher((s, sub) -> createRequestContextSubscriber(sub))); - } - - /** - * {@inheritDoc} - */ - @Override - public void destroy() throws Exception { - Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY); - } - - /** - * Configures the builder with {@link #defaultRequest()} and adds this as a {@link ExchangeFilterFunction} - * @return the {@link Consumer} to configure the builder - */ - public Consumer oauth2Configuration() { - return builder -> builder.defaultRequest(defaultRequest()).filter(this); - } - - /** - * Provides defaults for the {@link Authentication} using - * {@link SecurityContextHolder}. It also can default the {@link AbstractOAuth2Token} using the - * {@link #authentication(Authentication)}. - * @return the {@link Consumer} to populate the attributes - */ - public Consumer> defaultRequest() { - return spec -> spec.attributes(attrs -> { - populateDefaultAuthentication(attrs); - }); - } - - /** - * Modifies the {@link ClientRequest#attributes()} to include the {@link Authentication} used to - * look up and save the {@link AbstractOAuth2Token}. The value is defaulted in - * {@link ServletBearerExchangeFilterFunction#defaultRequest()} - * - * @param authentication the {@link Authentication} to use. - * @return the {@link Consumer} to populate the attributes - */ - public static Consumer> authentication(Authentication authentication) { - return attributes -> attributes.put(AUTHENTICATION_ATTR_NAME, authentication); - } - - /** - * {@inheritDoc} - */ - @Override - public Mono filter(ClientRequest request, ExchangeFunction next) { - return mergeRequestAttributesIfNecessary(request) - .filter(req -> req.attribute(AUTHENTICATION_ATTR_NAME).isPresent()) - .map(req -> getOAuth2Token(req.attributes())) - .map(token -> bearer(request, token)) - .flatMap(next::exchange) - .switchIfEmpty(Mono.defer(() -> next.exchange(request))); - } - - private Mono mergeRequestAttributesIfNecessary(ClientRequest request) { - if (request.attribute(AUTHENTICATION_ATTR_NAME).isPresent()) { - return Mono.just(request); - } - - return mergeRequestAttributesFromContext(request); - } - - private Mono mergeRequestAttributesFromContext(ClientRequest request) { - ClientRequest.Builder builder = ClientRequest.from(request); - return Mono.subscriberContext() - .map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))) - .map(ClientRequest.Builder::build); - } - - private void populateRequestAttributes(Map attrs, Context ctx) { - RequestContextDataHolder holder = RequestContextSubscriber.getRequestContext(ctx); - if (holder == null) { - return; - } - if (holder.getAuthentication() != null) { - attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, holder.getAuthentication()); - } - } - - private AbstractOAuth2Token getOAuth2Token(Map attrs) { - Authentication authentication = (Authentication) attrs.get(AUTHENTICATION_ATTR_NAME); - if (authentication.getCredentials() instanceof AbstractOAuth2Token) { - return (AbstractOAuth2Token) authentication.getCredentials(); - } - return null; - } - - private ClientRequest bearer(ClientRequest request, AbstractOAuth2Token token) { - return ClientRequest.from(request) - .headers(headers -> headers.setBearerAuth(token.getTokenValue())) - .build(); - } - - private CoreSubscriber createRequestContextSubscriber(CoreSubscriber delegate) { - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - return new RequestContextSubscriber<>(delegate, authentication); - } - - private void populateDefaultAuthentication(Map attrs) { - if (attrs.containsKey(AUTHENTICATION_ATTR_NAME)) { - return; - } - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication); - } - - private static class RequestContextDataHolder { - private final Authentication authentication; - - RequestContextDataHolder(Authentication authentication) { - this.authentication = authentication; - } - - public Authentication getAuthentication() { - return this.authentication; - } - } - - private static class RequestContextSubscriber implements CoreSubscriber { - private static final String REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME = - RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER"); - - private CoreSubscriber delegate; - private final Context context; - - private RequestContextSubscriber(CoreSubscriber delegate, - Authentication authentication) { - - this.delegate = delegate; - Context parentContext = this.delegate.currentContext(); - Context context; - if (authentication == null || parentContext.hasKey(REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME)) { - context = parentContext; - } else { - context = parentContext.put(REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME, - new RequestContextDataHolder(authentication)); - } - - this.context = context; - } - - @Nullable - static RequestContextDataHolder getRequestContext(Context ctx) { - return ctx.getOrDefault(REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME, null); - } - - @Override - public Context currentContext() { - return this.context; - } - - @Override - public void onSubscribe(Subscription s) { - this.delegate.onSubscribe(s); - } - - @Override - public void onNext(T t) { - this.delegate.onNext(t); - } - - @Override - public void onError(Throwable t) { - this.delegate.onError(t); - } - - @Override - public void onComplete() { - this.delegate.onComplete(); - } - } -} diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerExchangeFilterFunction.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServerBearerExchangeFilterFunction.java similarity index 60% rename from oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerExchangeFilterFunction.java rename to oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServerBearerExchangeFilterFunction.java index 0cb37fa85cb..700531d588c 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerExchangeFilterFunction.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServerBearerExchangeFilterFunction.java @@ -14,16 +14,11 @@ * limitations under the License. */ -package org.springframework.security.oauth2.server.resource.web.server; - -import java.util.Map; -import java.util.function.Consumer; +package org.springframework.security.oauth2.server.resource.web.reactive.function.client; import reactor.core.publisher.Mono; -import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; -import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.oauth2.core.AbstractOAuth2Token; @@ -52,52 +47,22 @@ * @author Josh Cummings * @since 5.2 */ -public class ServerBearerExchangeFilterFunction +public final class ServerBearerExchangeFilterFunction implements ExchangeFilterFunction { - private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName(); - - private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser", - AuthorityUtils.createAuthorityList("ROLE_USER")); - - /** - * Modifies the {@link ClientRequest#attributes()} to include the {@link Authentication} to be used for - * providing the Bearer Token. Example usage: - * - *
-	 * WebClient webClient = WebClient.builder()
-	 *    .filter(new ServerBearerExchangeFilterFunction())
-	 *    .build();
-	 * Mono response = webClient
-	 *    .get()
-	 *    .uri(uri)
-	 *    .attributes(authentication(authentication))
-	 *    // ...
-	 *    .retrieve()
-	 *    .bodyToMono(String.class);
-	 * 
- * @param authentication the {@link Authentication} to use - * @return the {@link Consumer} to populate the client request attributes - */ - public static Consumer> authentication(Authentication authentication) { - return attributes -> attributes.put(AUTHENTICATION_ATTR_NAME, authentication); - } - /** * {@inheritDoc} */ @Override public Mono filter(ClientRequest request, ExchangeFunction next) { - return oauth2Token(request.attributes()) - .map(oauth2Token -> bearer(request, oauth2Token)) + return oauth2Token() + .map(token -> bearer(request, token)) .defaultIfEmpty(request) .flatMap(next::exchange); } - private Mono oauth2Token(Map attrs) { - return Mono.justOrEmpty(attrs.get(AUTHENTICATION_ATTR_NAME)) - .cast(Authentication.class) - .switchIfEmpty(currentAuthentication()) + private Mono oauth2Token() { + return currentAuthentication() .filter(authentication -> authentication.getCredentials() instanceof AbstractOAuth2Token) .map(Authentication::getCredentials) .cast(AbstractOAuth2Token.class); @@ -105,8 +70,7 @@ private Mono oauth2Token(Map attrs) { private Mono currentAuthentication() { return ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .defaultIfEmpty(ANONYMOUS_USER_TOKEN); + .map(SecurityContext::getAuthentication); } private ClientRequest bearer(ClientRequest request, AbstractOAuth2Token token) { diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServletBearerExchangeFilterFunction.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServletBearerExchangeFilterFunction.java new file mode 100644 index 00000000000..e85540761be --- /dev/null +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServletBearerExchangeFilterFunction.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.resource.web.reactive.function.client; + +import reactor.core.publisher.Mono; + +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import org.springframework.web.reactive.function.client.ExchangeFunction; + +/** + * An {@link ExchangeFilterFunction} that adds the + * Bearer Token + * from an existing {@link AbstractOAuth2Token} tied to the current {@link Authentication}. + * + * Suitable for Servlet applications, applying it to a typical {@link org.springframework.web.reactive.function.client.WebClient} + * configuration: + * + *
+ *  @Bean
+ *  WebClient webClient() {
+ *      ServletBearerExchangeFilterFunction bearer = new ServletBearerExchangeFilterFunction();
+ *      return WebClient.builder()
+ *              .filter(bearer).build();
+ *  }
+ * 
+ * + * @author Josh Cummings + * @since 5.2 + */ +public final class ServletBearerExchangeFilterFunction + implements ExchangeFilterFunction { + + /** + * {@inheritDoc} + */ + @Override + public Mono filter(ClientRequest request, ExchangeFunction next) { + return oauth2Token() + .map(token -> bearer(request, token)) + .defaultIfEmpty(request) + .flatMap(next::exchange); + } + + private Mono oauth2Token() { + return currentAuthentication() + .filter(authentication -> authentication.getCredentials() instanceof AbstractOAuth2Token) + .map(Authentication::getCredentials) + .cast(AbstractOAuth2Token.class); + } + + private Mono currentAuthentication() { + return Mono.justOrEmpty(SecurityContextHolder.getContext().getAuthentication()); + } + + private ClientRequest bearer(ClientRequest request, AbstractOAuth2Token token) { + return ClientRequest.from(request) + .headers(headers -> headers.setBearerAuth(token.getTokenValue())) + .build(); + } +} diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerExchangeFilterFunctionTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServerBearerExchangeFilterFunctionTests.java similarity index 85% rename from oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerExchangeFilterFunctionTests.java rename to oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServerBearerExchangeFilterFunctionTests.java index 0a5ac9b0f55..22bdb72bcde 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServerBearerExchangeFilterFunctionTests.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.security.oauth2.server.resource.web.server; +package org.springframework.security.oauth2.server.resource.web.reactive.function.client; import java.net.URI; import java.time.Duration; @@ -25,6 +25,7 @@ import org.junit.Test; import org.springframework.http.HttpHeaders; +import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -34,7 +35,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.http.HttpMethod.GET; -import static org.springframework.security.oauth2.server.resource.web.ServletBearerExchangeFilterFunction.authentication; /** * Tests for {@link ServerBearerExchangeFilterFunction} @@ -80,26 +80,30 @@ public void filterWhenAuthenticatedThenAuthorizationHeaderNull() throws Exceptio .isEqualTo("Bearer " + this.accessToken.getTokenValue()); } + // gh-7353 @Test - public void filterWhenAuthenticationAttributeThenAuthorizationHeader() { + public void filterWhenAuthenticatedWithOtherTokenThenAuthorizationHeaderNull() throws Exception { ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(authentication(this.authentication)) .build(); - this.function.filter(request, this.exchange).block(); + TestingAuthenticationToken token = new TestingAuthenticationToken("user", "pass"); + this.function.filter(request, this.exchange) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(token)) + .block(); assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) - .isEqualTo("Bearer " + this.accessToken.getTokenValue()); + .isNull(); } @Test public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .header(HttpHeaders.AUTHORIZATION, "Existing") - .attributes(authentication(this.authentication)) .build(); - this.function.filter(request, this.exchange).block(); + this.function.filter(request, this.exchange) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)) + .block(); HttpHeaders headers = this.exchange.getRequest().headers(); assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/ServletBearerExchangeFilterFunctionTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServletBearerExchangeFilterFunctionTests.java similarity index 82% rename from oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/ServletBearerExchangeFilterFunctionTests.java rename to oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServletBearerExchangeFilterFunctionTests.java index f1217a63a5f..adeb9be1852 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/ServletBearerExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServletBearerExchangeFilterFunctionTests.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.security.oauth2.server.resource.web; +package org.springframework.security.oauth2.server.resource.web.reactive.function.client; import java.net.URI; import java.time.Duration; @@ -28,15 +28,16 @@ import org.mockito.junit.MockitoJUnitRunner; import org.springframework.http.HttpHeaders; +import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken; +import org.springframework.security.oauth2.server.resource.web.MockExchangeFunction; import org.springframework.web.reactive.function.client.ClientRequest; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.http.HttpMethod.GET; -import static org.springframework.security.oauth2.server.resource.web.ServletBearerExchangeFilterFunction.authentication; /** * Tests for {@link ServletBearerExchangeFilterFunction} @@ -53,6 +54,7 @@ public class ServletBearerExchangeFilterFunctionTests { "token-0", Instant.now(), Instant.now().plus(Duration.ofDays(1))); + private Authentication authentication = new AbstractOAuth2TokenAuthenticationToken(accessToken) { @Override public Map getTokenAttributes() { @@ -72,13 +74,15 @@ public void filterWhenUnauthenticatedThenAuthorizationHeaderNull() { this.function.filter(request, this.exchange).block(); - assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) + .isNull(); } + // gh-7353 @Test - public void filterWhenAuthenticatedThenAuthorizationHeaderNull() throws Exception { - this.function.afterPropertiesSet(); - SecurityContextHolder.getContext().setAuthentication(this.authentication); + public void filterWhenAuthenticatedWithOtherTokenThenAuthorizationHeaderNull() throws Exception { + TestingAuthenticationToken token = new TestingAuthenticationToken("user", "pass"); + SecurityContextHolder.getContext().setAuthentication(token); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .build(); @@ -86,13 +90,14 @@ public void filterWhenAuthenticatedThenAuthorizationHeaderNull() throws Exceptio this.function.filter(request, this.exchange).block(); assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) - .isEqualTo("Bearer " + this.accessToken.getTokenValue()); + .isNull(); } @Test - public void filterWhenAuthenticationAttributeThenAuthorizationHeader() { + public void filterWhenAuthenticatedThenAuthorizationHeader() throws Exception { + SecurityContextHolder.getContext().setAuthentication(this.authentication); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(authentication(this.authentication)) .build(); this.function.filter(request, this.exchange).block(); @@ -102,10 +107,11 @@ public void filterWhenAuthenticationAttributeThenAuthorizationHeader() { } @Test - public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { + public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() throws Exception { + SecurityContextHolder.getContext().setAuthentication(this.authentication); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .header(HttpHeaders.AUTHORIZATION, "Existing") - .attributes(authentication(this.authentication)) .build(); this.function.filter(request, this.exchange).block();