Skip to content

ServletOAuth2AuthorizedClientExchangeFilterFunction supports chaining #6526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.security.oauth2.client.web.reactive.function.client;

import org.reactivestreams.Subscription;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
Expand Down Expand Up @@ -44,8 +45,12 @@
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators;
import reactor.core.scheduler.Schedulers;
import reactor.util.context.Context;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
Expand Down Expand Up @@ -108,6 +113,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();

private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName();

private Clock clock = Clock.systemUTC();

private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
Expand All @@ -123,11 +130,14 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement

private String defaultClientRegistrationId;

public ServletOAuth2AuthorizedClientExchangeFilterFunction() {}
public ServletOAuth2AuthorizedClientExchangeFilterFunction() {
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub)));
}

public ServletOAuth2AuthorizedClientExchangeFilterFunction(
ClientRegistrationRepository clientRegistrationRepository,
OAuth2AuthorizedClientRepository authorizedClientRepository) {
this();
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientRepository = authorizedClientRepository;
}
Expand Down Expand Up @@ -266,15 +276,36 @@ public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {

@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
.map(OAuth2AuthorizedClient.class::cast);
return Mono.justOrEmpty(attribute)
.flatMap(authorizedClient -> authorizedClient(request, next, authorizedClient))
return Mono.just(request)
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
.switchIfEmpty(mergeRequestAttributesFromContext(request))
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
.flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes())))
.map(authorizedClient -> bearer(request, authorizedClient))
.flatMap(next::exchange)
.switchIfEmpty(next.exchange(request));
}

private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) {
return Mono.just(ClientRequest.from(request))
.flatMap(builder -> Mono.subscriberContext()
.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))))
.map(ClientRequest.Builder::build);
}

private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want Context to override potentially existing attrs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I original had attrs.putIfAbsent() but didn't see any reason for it. Do you anticipate a potential issue with overriding? The attributes in the Context will always be what the filter needs so I don't see any issue with this. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applied putIfAbsent()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My concern was if someone had a custom Authentication they were working with, it might be explicitly set as an attribute. Perhaps this is so that a global access token is used for client_credentials access tokens.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Ok so putIfAbsent() is needed...and it was updated in latest commit

if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) {
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME));
}
if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME));
}
if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) {
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME));
}
populateDefaultOAuth2AuthorizedClient(attrs);
}

private void populateDefaultRequestResponse(Map<String, Object> attrs) {
if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey(
HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
Expand Down Expand Up @@ -425,6 +456,19 @@ private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient autho
.build();
}

private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) {
HttpServletRequest request = null;
HttpServletResponse response = null;
ServletRequestAttributes requestAttributes =
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
if (requestAttributes != null) {
request = requestAttributes.getRequest();
response = requestAttributes.getResponse();
}
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
return new RequestContextSubscriber<>(delegate, request, response, authentication);
}

private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) {
return BodyInserters
.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
Expand Down Expand Up @@ -498,4 +542,55 @@ private UnsupportedOperationException unsupported() {
return new UnsupportedOperationException("Not Supported");
}
}

private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
private final CoreSubscriber<T> delegate;
private final HttpServletRequest request;
private final HttpServletResponse response;
private final Authentication authentication;

private RequestContextSubscriber(CoreSubscriber<T> delegate,
HttpServletRequest request,
HttpServletResponse response,
Authentication authentication) {
this.delegate = delegate;
this.request = request;
this.response = response;
this.authentication = authentication;
}

@Override
public Context currentContext() {
Context context = this.delegate.currentContext();
if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) {
return context;
}
return Context.of(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jgrandja @rwinch
This subscriber applies for each reactive chain by Hooks.onLastOperator() globally
Context does not alllow null values (see reactor/reactor-core#1800)
what if Mono.subscribe()/block() is executed out of security or web context (rabbitmq listener or just some scheduled task...)?

Copy link
Contributor

@robotmrv robotmrv Aug 3, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Context javadoc states

 * Note that contexts are optimized for low cardinality key/value storage, and a user
 * might want to associate a dedicated mutable structure to a single key to represent his
 * own context instead of using multiple {@link #put}, which could be more costly.
 * Past five user key/value pair, the {@link Context} will use a copy-on-write
 * implementation backed by a new {@link java.util.Map} on each {@link #put}.

https://github.com/reactor/reactor-core/blob/master/reactor-core/src/main/java/reactor/util/context/Context.java#L36-L44

but this subscriber creates 4 keys (occupies almost all "optimmized" implementations of Context)
and as far as I understand CONTEXT_DEFAULTED_ATTR_NAME key is just a marker

Maybe it would be better to put all this attrs into some holder so you can reduce number of keys from 4 to 1. And it allows remove extra CONTEXT_DEFAULTED_ATTR_NAME

I mean that it is not just spring security who populates context and it could impact overall performance if context is populated by other libraries frequently

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@robotmrv Thanks for the report. Can you create a ticket or a a PR for this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rwinch
I've created new issue
please see #7228

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and here is PR #7235

CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE,
HTTP_SERVLET_REQUEST_ATTR_NAME, this.request,
HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response,
AUTHENTICATION_ATTR_NAME, this.authentication);
}

@Override
public void onSubscribe(Subscription s) {
this.delegate.onSubscribe(s);
}

@Override
public void onNext(T t) {
this.delegate.onNext(t);
}

@Override
public void onError(Throwable t) {
this.delegate.onError(t);
}

@Override
public void onComplete() {
this.delegate.onComplete();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,11 @@
import java.util.Optional;
import java.util.function.Consumer;

import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;
import static org.springframework.http.HttpMethod.GET;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*;

Expand Down Expand Up @@ -572,6 +570,55 @@ public void filterWhenNotExpiredThenShouldRefreshFalse() {
assertThat(getBody(request0)).isEmpty();
}

// gh-6483
@Test
public void filterWhenChainedThenDefaultsStillAvailable() {
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
this.clientRegistrationRepository, this.authorizedClientRepository);
this.function.setDefaultOAuth2AuthorizedClient(true);

MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));

OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
user, authorities, this.registration.getRegistrationId());
SecurityContextHolder.getContext().setAuthentication(authentication);

OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration, "principalName", this.accessToken);
when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()),
eq(authentication), eq(servletRequest))).thenReturn(authorizedClient);

// Default request attributes set
final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com"))
.attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build();

// Default request attributes NOT set
final ClientRequest request2 = ClientRequest.create(GET, URI.create("https://example2.com")).build();

this.function.filter(request1, this.exchange)
.flatMap(response -> this.function.filter(request2, this.exchange))
.block();

List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(2);

ClientRequest request = requests.get(0);
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request.url().toASCIIString()).isEqualTo("https://example1.com");
assertThat(request.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request)).isEmpty();

request = requests.get(1);
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request.url().toASCIIString()).isEqualTo("https://example2.com");
assertThat(request.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request)).isEmpty();
}

private static String getBody(ClientRequest request) {
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));
Expand Down