16
16
17
17
package org .springframework .security .oauth2 .client .web .reactive .function .client ;
18
18
19
+ import org .reactivestreams .Subscription ;
20
+ import org .springframework .beans .factory .DisposableBean ;
21
+ import org .springframework .beans .factory .InitializingBean ;
19
22
import org .springframework .http .HttpHeaders ;
20
23
import org .springframework .http .HttpMethod ;
21
24
import org .springframework .http .MediaType ;
44
47
import org .springframework .web .reactive .function .client .ExchangeFilterFunction ;
45
48
import org .springframework .web .reactive .function .client .ExchangeFunction ;
46
49
import org .springframework .web .reactive .function .client .WebClient ;
50
+ import reactor .core .CoreSubscriber ;
51
+ import reactor .core .publisher .Hooks ;
47
52
import reactor .core .publisher .Mono ;
53
+ import reactor .core .publisher .Operators ;
48
54
import reactor .core .scheduler .Schedulers ;
55
+ import reactor .util .context .Context ;
49
56
50
57
import javax .servlet .http .HttpServletRequest ;
51
58
import javax .servlet .http .HttpServletResponse ;
98
105
* @author Rob Winch
99
106
* @since 5.1
100
107
*/
101
- public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
108
+ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
109
+ implements ExchangeFilterFunction , InitializingBean , DisposableBean {
110
+
102
111
/**
103
112
* The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
104
113
*/
@@ -108,6 +117,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
108
117
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest .class .getName ();
109
118
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse .class .getName ();
110
119
120
+ private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber .class .getName ();
121
+
111
122
private Clock clock = Clock .systemUTC ();
112
123
113
124
private Duration accessTokenExpiresSkew = Duration .ofMinutes (1 );
@@ -123,7 +134,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
123
134
124
135
private String defaultClientRegistrationId ;
125
136
126
- public ServletOAuth2AuthorizedClientExchangeFilterFunction () {}
137
+ public ServletOAuth2AuthorizedClientExchangeFilterFunction () {
138
+ }
127
139
128
140
public ServletOAuth2AuthorizedClientExchangeFilterFunction (
129
141
ClientRegistrationRepository clientRegistrationRepository ,
@@ -132,6 +144,16 @@ public ServletOAuth2AuthorizedClientExchangeFilterFunction(
132
144
this .authorizedClientRepository = authorizedClientRepository ;
133
145
}
134
146
147
+ @ Override
148
+ public void afterPropertiesSet () throws Exception {
149
+ Hooks .onLastOperator (REQUEST_CONTEXT_OPERATOR_KEY , Operators .lift ((s , sub ) -> createRequestContextSubscriber (sub )));
150
+ }
151
+
152
+ @ Override
153
+ public void destroy () throws Exception {
154
+ Hooks .resetOnLastOperator (REQUEST_CONTEXT_OPERATOR_KEY );
155
+ }
156
+
135
157
/**
136
158
* Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
137
159
* client_credentials grant.
@@ -266,15 +288,36 @@ public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
266
288
267
289
@ Override
268
290
public Mono <ClientResponse > filter (ClientRequest request , ExchangeFunction next ) {
269
- Optional <OAuth2AuthorizedClient > attribute = request .attribute (OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME )
270
- .map (OAuth2AuthorizedClient .class ::cast );
271
- return Mono .justOrEmpty (attribute )
272
- .flatMap (authorizedClient -> authorizedClient (request , next , authorizedClient ))
291
+ return Mono .just (request )
292
+ .filter (req -> req .attribute (OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME ).isPresent ())
293
+ .switchIfEmpty (mergeRequestAttributesFromContext (request ))
294
+ .filter (req -> req .attribute (OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME ).isPresent ())
295
+ .flatMap (req -> authorizedClient (req , next , getOAuth2AuthorizedClient (req .attributes ())))
273
296
.map (authorizedClient -> bearer (request , authorizedClient ))
274
297
.flatMap (next ::exchange )
275
298
.switchIfEmpty (next .exchange (request ));
276
299
}
277
300
301
+ private Mono <ClientRequest > mergeRequestAttributesFromContext (ClientRequest request ) {
302
+ return Mono .just (ClientRequest .from (request ))
303
+ .flatMap (builder -> Mono .subscriberContext ()
304
+ .map (ctx -> builder .attributes (attrs -> populateRequestAttributes (attrs , ctx ))))
305
+ .map (ClientRequest .Builder ::build );
306
+ }
307
+
308
+ private void populateRequestAttributes (Map <String , Object > attrs , Context ctx ) {
309
+ if (ctx .hasKey (HTTP_SERVLET_REQUEST_ATTR_NAME )) {
310
+ attrs .putIfAbsent (HTTP_SERVLET_REQUEST_ATTR_NAME , ctx .get (HTTP_SERVLET_REQUEST_ATTR_NAME ));
311
+ }
312
+ if (ctx .hasKey (HTTP_SERVLET_RESPONSE_ATTR_NAME )) {
313
+ attrs .putIfAbsent (HTTP_SERVLET_RESPONSE_ATTR_NAME , ctx .get (HTTP_SERVLET_RESPONSE_ATTR_NAME ));
314
+ }
315
+ if (ctx .hasKey (AUTHENTICATION_ATTR_NAME )) {
316
+ attrs .putIfAbsent (AUTHENTICATION_ATTR_NAME , ctx .get (AUTHENTICATION_ATTR_NAME ));
317
+ }
318
+ populateDefaultOAuth2AuthorizedClient (attrs );
319
+ }
320
+
278
321
private void populateDefaultRequestResponse (Map <String , Object > attrs ) {
279
322
if (attrs .containsKey (HTTP_SERVLET_REQUEST_ATTR_NAME ) && attrs .containsKey (
280
323
HTTP_SERVLET_RESPONSE_ATTR_NAME )) {
@@ -435,6 +478,19 @@ private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient autho
435
478
.build ();
436
479
}
437
480
481
+ private <T > CoreSubscriber <T > createRequestContextSubscriber (CoreSubscriber <T > delegate ) {
482
+ HttpServletRequest request = null ;
483
+ HttpServletResponse response = null ;
484
+ ServletRequestAttributes requestAttributes =
485
+ (ServletRequestAttributes ) RequestContextHolder .getRequestAttributes ();
486
+ if (requestAttributes != null ) {
487
+ request = requestAttributes .getRequest ();
488
+ response = requestAttributes .getResponse ();
489
+ }
490
+ Authentication authentication = SecurityContextHolder .getContext ().getAuthentication ();
491
+ return new RequestContextSubscriber <>(delegate , request , response , authentication );
492
+ }
493
+
438
494
private static BodyInserters .FormInserter <String > refreshTokenBody (String refreshToken ) {
439
495
return BodyInserters
440
496
.fromFormData ("grant_type" , AuthorizationGrantType .REFRESH_TOKEN .getValue ())
@@ -508,4 +564,55 @@ private UnsupportedOperationException unsupported() {
508
564
return new UnsupportedOperationException ("Not Supported" );
509
565
}
510
566
}
567
+
568
+ private static class RequestContextSubscriber <T > implements CoreSubscriber <T > {
569
+ private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber .class .getName ().concat (".CONTEXT_DEFAULTED_ATTR_NAME" );
570
+ private final CoreSubscriber <T > delegate ;
571
+ private final HttpServletRequest request ;
572
+ private final HttpServletResponse response ;
573
+ private final Authentication authentication ;
574
+
575
+ private RequestContextSubscriber (CoreSubscriber <T > delegate ,
576
+ HttpServletRequest request ,
577
+ HttpServletResponse response ,
578
+ Authentication authentication ) {
579
+ this .delegate = delegate ;
580
+ this .request = request ;
581
+ this .response = response ;
582
+ this .authentication = authentication ;
583
+ }
584
+
585
+ @ Override
586
+ public Context currentContext () {
587
+ Context context = this .delegate .currentContext ();
588
+ if (context .hasKey (CONTEXT_DEFAULTED_ATTR_NAME )) {
589
+ return context ;
590
+ }
591
+ return Context .of (
592
+ CONTEXT_DEFAULTED_ATTR_NAME , Boolean .TRUE ,
593
+ HTTP_SERVLET_REQUEST_ATTR_NAME , this .request ,
594
+ HTTP_SERVLET_RESPONSE_ATTR_NAME , this .response ,
595
+ AUTHENTICATION_ATTR_NAME , this .authentication );
596
+ }
597
+
598
+ @ Override
599
+ public void onSubscribe (Subscription s ) {
600
+ this .delegate .onSubscribe (s );
601
+ }
602
+
603
+ @ Override
604
+ public void onNext (T t ) {
605
+ this .delegate .onNext (t );
606
+ }
607
+
608
+ @ Override
609
+ public void onError (Throwable t ) {
610
+ this .delegate .onError (t );
611
+ }
612
+
613
+ @ Override
614
+ public void onComplete () {
615
+ this .delegate .onComplete ();
616
+ }
617
+ }
511
618
}
0 commit comments