-
Notifications
You must be signed in to change notification settings - Fork 6.1k
ServletOAuth2AuthorizedClientExchangeFilterFunction supports chaining #6526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
|
||
package org.springframework.security.oauth2.client.web.reactive.function.client; | ||
|
||
import org.reactivestreams.Subscription; | ||
import org.springframework.http.HttpHeaders; | ||
import org.springframework.http.HttpMethod; | ||
import org.springframework.http.MediaType; | ||
|
@@ -44,8 +45,12 @@ | |
import org.springframework.web.reactive.function.client.ExchangeFilterFunction; | ||
import org.springframework.web.reactive.function.client.ExchangeFunction; | ||
import org.springframework.web.reactive.function.client.WebClient; | ||
import reactor.core.CoreSubscriber; | ||
import reactor.core.publisher.Hooks; | ||
import reactor.core.publisher.Mono; | ||
import reactor.core.publisher.Operators; | ||
import reactor.core.scheduler.Schedulers; | ||
import reactor.util.context.Context; | ||
|
||
import javax.servlet.http.HttpServletRequest; | ||
import javax.servlet.http.HttpServletResponse; | ||
|
@@ -108,6 +113,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement | |
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName(); | ||
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName(); | ||
|
||
private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName(); | ||
|
||
private Clock clock = Clock.systemUTC(); | ||
|
||
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); | ||
|
@@ -123,11 +130,14 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement | |
|
||
private String defaultClientRegistrationId; | ||
|
||
public ServletOAuth2AuthorizedClientExchangeFilterFunction() {} | ||
public ServletOAuth2AuthorizedClientExchangeFilterFunction() { | ||
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub))); | ||
} | ||
|
||
public ServletOAuth2AuthorizedClientExchangeFilterFunction( | ||
ClientRegistrationRepository clientRegistrationRepository, | ||
OAuth2AuthorizedClientRepository authorizedClientRepository) { | ||
this(); | ||
this.clientRegistrationRepository = clientRegistrationRepository; | ||
this.authorizedClientRepository = authorizedClientRepository; | ||
} | ||
|
@@ -266,15 +276,36 @@ public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) { | |
|
||
@Override | ||
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) { | ||
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME) | ||
.map(OAuth2AuthorizedClient.class::cast); | ||
return Mono.justOrEmpty(attribute) | ||
.flatMap(authorizedClient -> authorizedClient(request, next, authorizedClient)) | ||
return Mono.just(request) | ||
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) | ||
.switchIfEmpty(mergeRequestAttributesFromContext(request)) | ||
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) | ||
.flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes()))) | ||
.map(authorizedClient -> bearer(request, authorizedClient)) | ||
.flatMap(next::exchange) | ||
.switchIfEmpty(next.exchange(request)); | ||
} | ||
|
||
private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) { | ||
return Mono.just(ClientRequest.from(request)) | ||
.flatMap(builder -> Mono.subscriberContext() | ||
.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx)))) | ||
.map(ClientRequest.Builder::build); | ||
} | ||
|
||
private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) { | ||
if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) { | ||
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME)); | ||
} | ||
if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) { | ||
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME)); | ||
} | ||
if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) { | ||
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME)); | ||
} | ||
populateDefaultOAuth2AuthorizedClient(attrs); | ||
} | ||
|
||
private void populateDefaultRequestResponse(Map<String, Object> attrs) { | ||
if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey( | ||
HTTP_SERVLET_RESPONSE_ATTR_NAME)) { | ||
|
@@ -425,6 +456,19 @@ private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient autho | |
.build(); | ||
} | ||
|
||
private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) { | ||
HttpServletRequest request = null; | ||
HttpServletResponse response = null; | ||
ServletRequestAttributes requestAttributes = | ||
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); | ||
if (requestAttributes != null) { | ||
request = requestAttributes.getRequest(); | ||
response = requestAttributes.getResponse(); | ||
} | ||
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); | ||
return new RequestContextSubscriber<>(delegate, request, response, authentication); | ||
} | ||
|
||
private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) { | ||
return BodyInserters | ||
.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue()) | ||
|
@@ -498,4 +542,55 @@ private UnsupportedOperationException unsupported() { | |
return new UnsupportedOperationException("Not Supported"); | ||
} | ||
} | ||
|
||
private static class RequestContextSubscriber<T> implements CoreSubscriber<T> { | ||
private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME"); | ||
private final CoreSubscriber<T> delegate; | ||
private final HttpServletRequest request; | ||
private final HttpServletResponse response; | ||
private final Authentication authentication; | ||
|
||
private RequestContextSubscriber(CoreSubscriber<T> delegate, | ||
HttpServletRequest request, | ||
HttpServletResponse response, | ||
Authentication authentication) { | ||
this.delegate = delegate; | ||
this.request = request; | ||
this.response = response; | ||
this.authentication = authentication; | ||
} | ||
|
||
@Override | ||
public Context currentContext() { | ||
Context context = this.delegate.currentContext(); | ||
if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) { | ||
return context; | ||
} | ||
return Context.of( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jgrandja @rwinch There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Context javadoc states
but this subscriber creates 4 keys (occupies almost all "optimmized" implementations of Context) Maybe it would be better to put all this attrs into some holder so you can reduce number of keys from 4 to 1. And it allows remove extra I mean that it is not just spring security who populates context and it could impact overall performance if context is populated by other libraries frequently There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @robotmrv Thanks for the report. Can you create a ticket or a a PR for this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and here is PR #7235 |
||
CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE, | ||
HTTP_SERVLET_REQUEST_ATTR_NAME, this.request, | ||
HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response, | ||
AUTHENTICATION_ATTR_NAME, this.authentication); | ||
} | ||
|
||
@Override | ||
public void onSubscribe(Subscription s) { | ||
this.delegate.onSubscribe(s); | ||
} | ||
|
||
@Override | ||
public void onNext(T t) { | ||
this.delegate.onNext(t); | ||
} | ||
|
||
@Override | ||
public void onError(Throwable t) { | ||
this.delegate.onError(t); | ||
} | ||
|
||
@Override | ||
public void onComplete() { | ||
this.delegate.onComplete(); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want
Context
to override potentially existingattrs
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I original had
attrs.putIfAbsent()
but didn't see any reason for it. Do you anticipate a potential issue with overriding? The attributes in theContext
will always be what the filter needs so I don't see any issue with this. What do you think?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Applied
putIfAbsent()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My concern was if someone had a custom
Authentication
they were working with, it might be explicitly set as an attribute. Perhaps this is so that a global access token is used for client_credentials access tokens.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. Ok so
putIfAbsent()
is needed...and it was updated in latest commit