Skip to content

Commit 0c27f64

Browse files
committed
ServletOAuth2AuthorizedClientExchangeFilterFunction supports chaining
Fixes gh-6483
1 parent 0c2a7e0 commit 0c27f64

File tree

2 files changed

+231
-12
lines changed

2 files changed

+231
-12
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 113 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
package org.springframework.security.oauth2.client.web.reactive.function.client;
1818

19+
import org.reactivestreams.Subscription;
20+
import org.springframework.beans.factory.DisposableBean;
21+
import org.springframework.beans.factory.InitializingBean;
1922
import org.springframework.http.HttpHeaders;
2023
import org.springframework.http.HttpMethod;
2124
import org.springframework.http.MediaType;
@@ -44,8 +47,12 @@
4447
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
4548
import org.springframework.web.reactive.function.client.ExchangeFunction;
4649
import org.springframework.web.reactive.function.client.WebClient;
50+
import reactor.core.CoreSubscriber;
51+
import reactor.core.publisher.Hooks;
4752
import reactor.core.publisher.Mono;
53+
import reactor.core.publisher.Operators;
4854
import reactor.core.scheduler.Schedulers;
55+
import reactor.util.context.Context;
4956

5057
import javax.servlet.http.HttpServletRequest;
5158
import javax.servlet.http.HttpServletResponse;
@@ -98,7 +105,9 @@
98105
* @author Rob Winch
99106
* @since 5.1
100107
*/
101-
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
108+
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
109+
implements ExchangeFilterFunction, InitializingBean, DisposableBean {
110+
102111
/**
103112
* The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
104113
*/
@@ -108,6 +117,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
108117
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
109118
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
110119

120+
private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName();
121+
111122
private Clock clock = Clock.systemUTC();
112123

113124
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
@@ -123,7 +134,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
123134

124135
private String defaultClientRegistrationId;
125136

126-
public ServletOAuth2AuthorizedClientExchangeFilterFunction() {}
137+
public ServletOAuth2AuthorizedClientExchangeFilterFunction() {
138+
}
127139

128140
public ServletOAuth2AuthorizedClientExchangeFilterFunction(
129141
ClientRegistrationRepository clientRegistrationRepository,
@@ -132,6 +144,16 @@ public ServletOAuth2AuthorizedClientExchangeFilterFunction(
132144
this.authorizedClientRepository = authorizedClientRepository;
133145
}
134146

147+
@Override
148+
public void afterPropertiesSet() throws Exception {
149+
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub)));
150+
}
151+
152+
@Override
153+
public void destroy() throws Exception {
154+
Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY);
155+
}
156+
135157
/**
136158
* Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
137159
* client_credentials grant.
@@ -266,15 +288,36 @@ public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
266288

267289
@Override
268290
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
269-
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
270-
.map(OAuth2AuthorizedClient.class::cast);
271-
return Mono.justOrEmpty(attribute)
272-
.flatMap(authorizedClient -> authorizedClient(request, next, authorizedClient))
291+
return Mono.just(request)
292+
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
293+
.switchIfEmpty(mergeRequestAttributesFromContext(request))
294+
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
295+
.flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes())))
273296
.map(authorizedClient -> bearer(request, authorizedClient))
274297
.flatMap(next::exchange)
275298
.switchIfEmpty(next.exchange(request));
276299
}
277300

301+
private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) {
302+
return Mono.just(ClientRequest.from(request))
303+
.flatMap(builder -> Mono.subscriberContext()
304+
.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))))
305+
.map(ClientRequest.Builder::build);
306+
}
307+
308+
private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
309+
if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) {
310+
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME));
311+
}
312+
if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
313+
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME));
314+
}
315+
if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) {
316+
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME));
317+
}
318+
populateDefaultOAuth2AuthorizedClient(attrs);
319+
}
320+
278321
private void populateDefaultRequestResponse(Map<String, Object> attrs) {
279322
if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey(
280323
HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
@@ -435,6 +478,19 @@ private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient autho
435478
.build();
436479
}
437480

481+
private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) {
482+
HttpServletRequest request = null;
483+
HttpServletResponse response = null;
484+
ServletRequestAttributes requestAttributes =
485+
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
486+
if (requestAttributes != null) {
487+
request = requestAttributes.getRequest();
488+
response = requestAttributes.getResponse();
489+
}
490+
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
491+
return new RequestContextSubscriber<>(delegate, request, response, authentication);
492+
}
493+
438494
private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) {
439495
return BodyInserters
440496
.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
@@ -508,4 +564,55 @@ private UnsupportedOperationException unsupported() {
508564
return new UnsupportedOperationException("Not Supported");
509565
}
510566
}
567+
568+
private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
569+
private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
570+
private final CoreSubscriber<T> delegate;
571+
private final HttpServletRequest request;
572+
private final HttpServletResponse response;
573+
private final Authentication authentication;
574+
575+
private RequestContextSubscriber(CoreSubscriber<T> delegate,
576+
HttpServletRequest request,
577+
HttpServletResponse response,
578+
Authentication authentication) {
579+
this.delegate = delegate;
580+
this.request = request;
581+
this.response = response;
582+
this.authentication = authentication;
583+
}
584+
585+
@Override
586+
public Context currentContext() {
587+
Context context = this.delegate.currentContext();
588+
if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) {
589+
return context;
590+
}
591+
return Context.of(
592+
CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE,
593+
HTTP_SERVLET_REQUEST_ATTR_NAME, this.request,
594+
HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response,
595+
AUTHENTICATION_ATTR_NAME, this.authentication);
596+
}
597+
598+
@Override
599+
public void onSubscribe(Subscription s) {
600+
this.delegate.onSubscribe(s);
601+
}
602+
603+
@Override
604+
public void onNext(T t) {
605+
this.delegate.onNext(t);
606+
}
607+
608+
@Override
609+
public void onError(Throwable t) {
610+
this.delegate.onError(t);
611+
}
612+
613+
@Override
614+
public void onComplete() {
615+
this.delegate.onComplete();
616+
}
617+
}
511618
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 118 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,11 @@
7474
import java.util.Optional;
7575
import java.util.function.Consumer;
7676

77-
import static org.assertj.core.api.Assertions.*;
77+
import static org.assertj.core.api.Assertions.assertThat;
78+
import static org.assertj.core.api.Assertions.assertThatCode;
7879
import static org.mockito.ArgumentMatchers.any;
7980
import static org.mockito.ArgumentMatchers.eq;
80-
import static org.mockito.Mockito.mock;
81-
import static org.mockito.Mockito.never;
82-
import static org.mockito.Mockito.verify;
83-
import static org.mockito.Mockito.verifyZeroInteractions;
84-
import static org.mockito.Mockito.when;
81+
import static org.mockito.Mockito.*;
8582
import static org.springframework.http.HttpMethod.GET;
8683
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*;
8784

@@ -647,6 +644,121 @@ public void filterWhenNotExpiredThenShouldRefreshFalse() {
647644
assertThat(getBody(request0)).isEmpty();
648645
}
649646

647+
// gh-6483
648+
@Test
649+
public void filterWhenChainedThenDefaultsStillAvailable() throws Exception {
650+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
651+
this.clientRegistrationRepository, this.authorizedClientRepository);
652+
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
653+
this.function.setDefaultOAuth2AuthorizedClient(true);
654+
655+
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
656+
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
657+
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
658+
659+
OAuth2User user = mock(OAuth2User.class);
660+
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
661+
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
662+
user, authorities, this.registration.getRegistrationId());
663+
SecurityContextHolder.getContext().setAuthentication(authentication);
664+
665+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
666+
this.registration, "principalName", this.accessToken);
667+
when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()),
668+
eq(authentication), eq(servletRequest))).thenReturn(authorizedClient);
669+
670+
// Default request attributes set
671+
final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com"))
672+
.attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build();
673+
674+
// Default request attributes NOT set
675+
final ClientRequest request2 = ClientRequest.create(GET, URI.create("https://example2.com")).build();
676+
677+
this.function.filter(request1, this.exchange)
678+
.flatMap(response -> this.function.filter(request2, this.exchange))
679+
.block();
680+
681+
this.function.destroy(); // Hooks.onLastOperator() released
682+
683+
List<ClientRequest> requests = this.exchange.getRequests();
684+
assertThat(requests).hasSize(2);
685+
686+
ClientRequest request = requests.get(0);
687+
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
688+
assertThat(request.url().toASCIIString()).isEqualTo("https://example1.com");
689+
assertThat(request.method()).isEqualTo(HttpMethod.GET);
690+
assertThat(getBody(request)).isEmpty();
691+
692+
request = requests.get(1);
693+
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
694+
assertThat(request.url().toASCIIString()).isEqualTo("https://example2.com");
695+
assertThat(request.method()).isEqualTo(HttpMethod.GET);
696+
assertThat(getBody(request)).isEmpty();
697+
}
698+
699+
@Test
700+
public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvailable() throws Exception {
701+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
702+
this.clientRegistrationRepository, this.authorizedClientRepository);
703+
// this.function.afterPropertiesSet(); // Hooks.onLastOperator() NOT initialized
704+
this.function.setDefaultOAuth2AuthorizedClient(true);
705+
706+
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
707+
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
708+
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
709+
710+
OAuth2User user = mock(OAuth2User.class);
711+
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
712+
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
713+
user, authorities, this.registration.getRegistrationId());
714+
SecurityContextHolder.getContext().setAuthentication(authentication);
715+
716+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
717+
718+
this.function.filter(request, this.exchange).block();
719+
720+
List<ClientRequest> requests = this.exchange.getRequests();
721+
assertThat(requests).hasSize(1);
722+
723+
request = requests.get(0);
724+
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
725+
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
726+
assertThat(request.method()).isEqualTo(HttpMethod.GET);
727+
assertThat(getBody(request)).isEmpty();
728+
}
729+
730+
@Test
731+
public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsNotAvailable() throws Exception {
732+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
733+
this.clientRegistrationRepository, this.authorizedClientRepository);
734+
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
735+
this.function.destroy(); // Hooks.onLastOperator() released
736+
this.function.setDefaultOAuth2AuthorizedClient(true);
737+
738+
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
739+
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
740+
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
741+
742+
OAuth2User user = mock(OAuth2User.class);
743+
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
744+
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
745+
user, authorities, this.registration.getRegistrationId());
746+
SecurityContextHolder.getContext().setAuthentication(authentication);
747+
748+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
749+
750+
this.function.filter(request, this.exchange).block();
751+
752+
List<ClientRequest> requests = this.exchange.getRequests();
753+
assertThat(requests).hasSize(1);
754+
755+
request = requests.get(0);
756+
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
757+
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
758+
assertThat(request.method()).isEqualTo(HttpMethod.GET);
759+
assertThat(getBody(request)).isEmpty();
760+
}
761+
650762
private static String getBody(ClientRequest request) {
651763
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
652764
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));

0 commit comments

Comments
 (0)