Skip to content

Commit ffc43e0

Browse files
committed
Fix NPE in RequestContextSubscriber
RequestContextSubscriber could cause NPE if Mono/Flux.subscribe() was invoked outside of Web Context. In addition it replaced source Context with its own without respect to old data. Now Request Context Data is Propagated within holder class and it is added to existing reactor Context if Holder is not empty. Fixes gh-7228
1 parent 1de885e commit ffc43e0

File tree

2 files changed

+168
-32
lines changed

2 files changed

+168
-32
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 77 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.reactivestreams.Subscription;
2020
import org.springframework.beans.factory.DisposableBean;
2121
import org.springframework.beans.factory.InitializingBean;
22+
import org.springframework.lang.Nullable;
2223
import org.springframework.security.authentication.AnonymousAuthenticationToken;
2324
import org.springframework.security.core.Authentication;
2425
import org.springframework.security.core.GrantedAuthority;
@@ -95,6 +96,7 @@
9596
*
9697
* @author Rob Winch
9798
* @author Joe Grandja
99+
* @author Roman Matiushchenko
98100
* @since 5.1
99101
* @see OAuth2AuthorizedClientManager
100102
*/
@@ -174,7 +176,7 @@ private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManage
174176

175177
@Override
176178
public void afterPropertiesSet() throws Exception {
177-
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub)));
179+
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.liftPublisher((s, sub) -> createRequestContextSubscriberIfNecessary(sub)));
178180
}
179181

180182
@Override
@@ -378,14 +380,22 @@ private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest requ
378380
}
379381

380382
private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
381-
if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) {
382-
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME));
383-
}
384-
if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
385-
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME));
386-
}
387-
if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) {
388-
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME));
383+
RequestContextDataHolder holder = RequestContextSubscriber.getRequestContext(ctx);
384+
if (holder != null) {
385+
HttpServletRequest request = holder.getRequest();
386+
if (request != null) {
387+
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
388+
}
389+
390+
HttpServletResponse response = holder.getResponse();
391+
if (response != null) {
392+
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
393+
}
394+
395+
Authentication authentication = holder.getAuthentication();
396+
if (authentication != null) {
397+
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
398+
}
389399
}
390400
}
391401

@@ -472,7 +482,7 @@ private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient autho
472482
.build();
473483
}
474484

475-
private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) {
485+
<T> CoreSubscriber<T> createRequestContextSubscriberIfNecessary(CoreSubscriber<T> delegate) {
476486
HttpServletRequest request = null;
477487
HttpServletResponse response = null;
478488
ServletRequestAttributes requestAttributes =
@@ -482,6 +492,10 @@ private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> d
482492
response = requestAttributes.getResponse();
483493
}
484494
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
495+
if (authentication == null && request == null && response == null) {
496+
//do not need to create RequestContextSubscriber with empty data
497+
return delegate;
498+
}
485499
return new RequestContextSubscriber<>(delegate, request, response, authentication);
486500
}
487501

@@ -553,34 +567,37 @@ private UnsupportedOperationException unsupported() {
553567
}
554568
}
555569

556-
private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
557-
private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
570+
static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
571+
static final String REQUEST_CONTEXT_DATA_HOLDER =
572+
RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER");
558573
private final CoreSubscriber<T> delegate;
559-
private final HttpServletRequest request;
560-
private final HttpServletResponse response;
561-
private final Authentication authentication;
574+
private final Context context;
562575

563-
private RequestContextSubscriber(CoreSubscriber<T> delegate,
564-
HttpServletRequest request,
565-
HttpServletResponse response,
566-
Authentication authentication) {
576+
RequestContextSubscriber(CoreSubscriber<T> delegate,
577+
HttpServletRequest request,
578+
HttpServletResponse response,
579+
Authentication authentication) {
567580
this.delegate = delegate;
568-
this.request = request;
569-
this.response = response;
570-
this.authentication = authentication;
581+
582+
Context parentContext = this.delegate.currentContext();
583+
Context context;
584+
if (parentContext.hasKey(REQUEST_CONTEXT_DATA_HOLDER)) {
585+
context = parentContext;
586+
} else {
587+
context = parentContext.put(REQUEST_CONTEXT_DATA_HOLDER, new RequestContextDataHolder(request, response, authentication));
588+
}
589+
590+
this.context = context;
591+
}
592+
593+
@Nullable
594+
private static RequestContextDataHolder getRequestContext(Context ctx) {
595+
return ctx.getOrDefault(REQUEST_CONTEXT_DATA_HOLDER, null);
571596
}
572597

573598
@Override
574599
public Context currentContext() {
575-
Context context = this.delegate.currentContext();
576-
if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) {
577-
return context;
578-
}
579-
return Context.of(
580-
CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE,
581-
HTTP_SERVLET_REQUEST_ATTR_NAME, this.request,
582-
HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response,
583-
AUTHENTICATION_ATTR_NAME, this.authentication);
600+
return this.context;
584601
}
585602

586603
@Override
@@ -603,4 +620,33 @@ public void onComplete() {
603620
this.delegate.onComplete();
604621
}
605622
}
623+
624+
static class RequestContextDataHolder {
625+
private final HttpServletRequest request;
626+
private final HttpServletResponse response;
627+
private final Authentication authentication;
628+
629+
RequestContextDataHolder(@Nullable HttpServletRequest request,
630+
@Nullable HttpServletResponse response,
631+
@Nullable Authentication authentication) {
632+
this.request = request;
633+
this.response = response;
634+
this.authentication = authentication;
635+
}
636+
637+
@Nullable
638+
private HttpServletRequest getRequest() {
639+
return this.request;
640+
}
641+
642+
@Nullable
643+
private HttpServletResponse getResponse() {
644+
return this.response;
645+
}
646+
647+
@Nullable
648+
private Authentication getAuthentication() {
649+
return this.authentication;
650+
}
651+
}
606652
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@
7272
import org.springframework.web.reactive.function.BodyInserter;
7373
import org.springframework.web.reactive.function.client.ClientRequest;
7474
import org.springframework.web.reactive.function.client.WebClient;
75+
import reactor.core.CoreSubscriber;
76+
import reactor.core.publisher.BaseSubscriber;
77+
import reactor.core.publisher.Mono;
78+
import reactor.util.context.Context;
7579

7680
import java.net.URI;
7781
import java.time.Duration;
@@ -84,6 +88,7 @@
8488
import java.util.function.Consumer;
8589

8690
import static org.assertj.core.api.Assertions.assertThat;
91+
import static org.assertj.core.api.Assertions.assertThatCode;
8792
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
8893
import static org.mockito.Mockito.*;
8994
import static org.springframework.http.HttpMethod.GET;
@@ -144,9 +149,10 @@ public void setup() {
144149
}
145150

146151
@After
147-
public void cleanup() {
152+
public void cleanup() throws Exception {
148153
SecurityContextHolder.clearContext();
149154
RequestContextHolder.resetRequestAttributes();
155+
this.function.destroy();
150156
}
151157

152158
@Test
@@ -633,6 +639,90 @@ public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsN
633639
assertThat(getBody(request)).isEmpty();
634640
}
635641

642+
// gh-7228
643+
@Test
644+
public void afterPropertiesSetWhenHooksInitAndOutsideWebSecurityContextThenShouldNotThrowException() throws Exception {
645+
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
646+
assertThatCode(() -> Mono.subscriberContext().block())
647+
.as("RequestContext Hook brakes application outside of web/security context")
648+
.doesNotThrowAnyException();
649+
}
650+
651+
@Test
652+
public void createRequestContextSubscriberIfNecessaryWhenOutsideWebSecurityContextThenReturnOriginalSubscriber() throws Exception {
653+
BaseSubscriber<Object> originalSubscriber = new BaseSubscriber<Object>() {};
654+
CoreSubscriber<Object> resultSubscriber = this.function.createRequestContextSubscriberIfNecessary(originalSubscriber);
655+
assertThat(resultSubscriber).isSameAs(originalSubscriber);
656+
}
657+
658+
// gh-7228
659+
@Test
660+
public void createRequestContextSubscriberWhenRequestResponseProvidedThenCreateWithParentContext() throws Exception {
661+
testRequestContextSubscriber(new MockHttpServletRequest(), new MockHttpServletResponse(), null);
662+
}
663+
664+
// gh-7228
665+
@Test
666+
public void createRequestContextSubscriberWhenAuthenticationProvidedThenCreateWithParentContext() throws Exception {
667+
testRequestContextSubscriber(null, null, this.authentication);
668+
}
669+
670+
@Test
671+
public void createRequestContextSubscriberWhenParentContextHasDataHolderThenShouldReuseParentContext() throws Exception {
672+
RequestContextDataHolder testValue = new RequestContextDataHolder(null, null, null);
673+
final Context parentContext = Context.of(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, testValue);
674+
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
675+
@Override
676+
public Context currentContext() {
677+
return parentContext;
678+
}
679+
};
680+
681+
RequestContextSubscriber<Object> requestContextSubscriber =
682+
new RequestContextSubscriber<>(parent, null, null, authentication);
683+
684+
Context resultContext = requestContextSubscriber.currentContext();
685+
686+
assertThat(resultContext)
687+
.describedAs("parent context was replaced")
688+
.isSameAs(parentContext);
689+
}
690+
691+
private void testRequestContextSubscriber(MockHttpServletRequest servletRequest,
692+
MockHttpServletResponse servletResponse,
693+
Authentication authentication) {
694+
String testKey = "test_key";
695+
String testValue = "test_value";
696+
697+
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
698+
@Override
699+
public Context currentContext() {
700+
return Context.of(testKey, testValue);
701+
}
702+
};
703+
704+
RequestContextSubscriber<Object> requestContextSubscriber =
705+
new RequestContextSubscriber<>(parent, servletRequest, servletResponse, authentication);
706+
707+
Context resultContext = requestContextSubscriber.currentContext();
708+
709+
assertThat(resultContext)
710+
.describedAs("result context is null")
711+
.isNotNull();
712+
713+
assertThat(resultContext.getOrEmpty(testKey))
714+
.describedAs("context is replaced")
715+
.hasValue(testValue);
716+
717+
Object dataHolder = resultContext.getOrDefault(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, null);
718+
assertThat(dataHolder)
719+
.describedAs("context is not populated with REQUEST_CONTEXT_DATA_HOLDER")
720+
.isNotNull()
721+
.hasFieldOrPropertyWithValue("request", servletRequest)
722+
.hasFieldOrPropertyWithValue("response", servletResponse)
723+
.hasFieldOrPropertyWithValue("authentication", authentication);
724+
}
725+
636726
private static String getBody(ClientRequest request) {
637727
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
638728
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));

0 commit comments

Comments
 (0)