Skip to content

Commit 16fe1c5

Browse files
committed
Expose RestOperations in NimbusJwtDecoderJwkSupport
Fixes gh-5603
1 parent 1198403 commit 16fe1c5

File tree

2 files changed

+102
-49
lines changed

2 files changed

+102
-49
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,42 @@
1515
*/
1616
package org.springframework.security.oauth2.jwt;
1717

18-
import java.net.MalformedURLException;
19-
import java.net.URL;
20-
import java.text.ParseException;
21-
import java.time.Instant;
22-
import java.util.LinkedHashMap;
23-
import java.util.Map;
24-
2518
import com.nimbusds.jose.JWSAlgorithm;
2619
import com.nimbusds.jose.RemoteKeySourceException;
2720
import com.nimbusds.jose.jwk.source.JWKSource;
2821
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
2922
import com.nimbusds.jose.proc.JWSKeySelector;
3023
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
3124
import com.nimbusds.jose.proc.SecurityContext;
32-
import com.nimbusds.jose.util.DefaultResourceRetriever;
25+
import com.nimbusds.jose.util.Resource;
3326
import com.nimbusds.jose.util.ResourceRetriever;
3427
import com.nimbusds.jwt.JWT;
3528
import com.nimbusds.jwt.JWTClaimsSet;
3629
import com.nimbusds.jwt.JWTParser;
3730
import com.nimbusds.jwt.SignedJWT;
3831
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
3932
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
40-
33+
import org.springframework.http.HttpHeaders;
34+
import org.springframework.http.HttpMethod;
35+
import org.springframework.http.MediaType;
36+
import org.springframework.http.RequestEntity;
37+
import org.springframework.http.ResponseEntity;
4138
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
4239
import org.springframework.util.Assert;
40+
import org.springframework.web.client.RestOperations;
41+
import org.springframework.web.client.RestTemplate;
42+
43+
import java.io.IOException;
44+
import java.net.MalformedURLException;
45+
import java.net.URL;
46+
import java.text.ParseException;
47+
import java.time.Instant;
48+
import java.util.Collections;
49+
import java.util.LinkedHashMap;
50+
import java.util.Map;
4351

4452
/**
45-
* An implementation of a {@link JwtDecoder} that "decodes" a
53+
* An implementation of a {@link JwtDecoder} that "decodes" a
4654
* JSON Web Token (JWT) and additionally verifies it's digital signature if the JWT is a
4755
* JSON Web Signature (JWS). The public key used for verification is obtained from the
4856
* JSON Web Key (JWK) Set {@code URL} supplied via the constructor.
@@ -63,9 +71,9 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
6371
private static final String DECODING_ERROR_MESSAGE_TEMPLATE =
6472
"An error occurred while attempting to decode the Jwt: %s";
6573

66-
private final URL jwkSetUrl;
6774
private final JWSAlgorithm jwsAlgorithm;
6875
private final ConfigurableJWTProcessor<SecurityContext> jwtProcessor;
76+
private final RestOperationsResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever();
6977

7078
/**
7179
* Constructs a {@code NimbusJwtDecoderJwkSupport} using the provided parameters.
@@ -85,29 +93,25 @@ public NimbusJwtDecoderJwkSupport(String jwkSetUrl) {
8593
public NimbusJwtDecoderJwkSupport(String jwkSetUrl, String jwsAlgorithm) {
8694
Assert.hasText(jwkSetUrl, "jwkSetUrl cannot be empty");
8795
Assert.hasText(jwsAlgorithm, "jwsAlgorithm cannot be empty");
96+
JWKSource jwkSource;
8897
try {
89-
this.jwkSetUrl = new URL(jwkSetUrl);
98+
jwkSource = new RemoteJWKSet(new URL(jwkSetUrl), this.jwkSetRetriever);
9099
} catch (MalformedURLException ex) {
91-
throw new IllegalArgumentException("Invalid JWK Set URL " + jwkSetUrl + " : " + ex.getMessage(), ex);
100+
throw new IllegalArgumentException("Invalid JWK Set URL \"" + jwkSetUrl + "\" : " + ex.getMessage(), ex);
92101
}
93102
this.jwsAlgorithm = JWSAlgorithm.parse(jwsAlgorithm);
94-
95-
ResourceRetriever jwkSetRetriever = new DefaultResourceRetriever(30000, 30000);
96-
JWKSource jwkSource = new RemoteJWKSet(this.jwkSetUrl, jwkSetRetriever);
97103
JWSKeySelector<SecurityContext> jwsKeySelector =
98104
new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
99-
100105
this.jwtProcessor = new DefaultJWTProcessor<>();
101106
this.jwtProcessor.setJWSKeySelector(jwsKeySelector);
102107
}
103108

104109
@Override
105110
public Jwt decode(String token) throws JwtException {
106111
JWT jwt = this.parse(token);
107-
if ( jwt instanceof SignedJWT ) {
112+
if (jwt instanceof SignedJWT) {
108113
return this.createJwt(token, jwt);
109114
}
110-
111115
throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm());
112116
}
113117

@@ -158,4 +162,39 @@ private Jwt createJwt(String token, JWT parsedJwt) {
158162

159163
return jwt;
160164
}
165+
166+
/**
167+
* Sets the {@link RestOperations} used when requesting the JSON Web Key (JWK) Set.
168+
*
169+
* @since 5.1
170+
* @param restOperations the {@link RestOperations} used when requesting the JSON Web Key (JWK) Set
171+
*/
172+
public final void setRestOperations(RestOperations restOperations) {
173+
Assert.notNull(restOperations, "restOperations cannot be null");
174+
this.jwkSetRetriever.restOperations = restOperations;
175+
}
176+
177+
private static class RestOperationsResourceRetriever implements ResourceRetriever {
178+
private RestOperations restOperations = new RestTemplate();
179+
180+
@Override
181+
public Resource retrieveResource(URL url) throws IOException {
182+
HttpHeaders headers = new HttpHeaders();
183+
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8));
184+
185+
ResponseEntity<String> response;
186+
try {
187+
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI());
188+
response = this.restOperations.exchange(request, String.class);
189+
} catch (Exception ex) {
190+
throw new IOException(ex);
191+
}
192+
193+
if (response.getStatusCodeValue() != 200) {
194+
throw new IOException(response.toString());
195+
}
196+
197+
return new Resource(response.getBody(), "UTF-8");
198+
}
199+
}
161200
}

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2017 the original author or authors.
2+
* Copyright 2002-2018 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -24,23 +24,22 @@
2424
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
2525
import okhttp3.mockwebserver.MockResponse;
2626
import okhttp3.mockwebserver.MockWebServer;
27+
import org.assertj.core.api.Assertions;
2728
import org.junit.Test;
2829
import org.junit.runner.RunWith;
2930
import org.powermock.core.classloader.annotations.PowerMockIgnore;
3031
import org.powermock.core.classloader.annotations.PrepareForTest;
3132
import org.powermock.modules.junit4.PowerMockRunner;
32-
33+
import org.springframework.http.RequestEntity;
3334
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
35+
import org.springframework.web.client.RestTemplate;
3436

3537
import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode;
3638
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
37-
import static org.mockito.ArgumentMatchers.any;
38-
import static org.mockito.ArgumentMatchers.anyString;
39-
import static org.mockito.ArgumentMatchers.eq;
39+
import static org.mockito.ArgumentMatchers.*;
4040
import static org.mockito.Mockito.mock;
41-
import static org.powermock.api.mockito.PowerMockito.mockStatic;
42-
import static org.powermock.api.mockito.PowerMockito.when;
43-
import static org.powermock.api.mockito.PowerMockito.whenNew;
41+
import static org.mockito.Mockito.verify;
42+
import static org.powermock.api.mockito.PowerMockito.*;
4443

4544
/**
4645
* Tests for {@link NimbusJwtDecoderJwkSupport}.
@@ -62,6 +61,8 @@ public class NimbusJwtDecoderJwkSupportTests {
6261
private static final String MALFORMED_JWT = "eyJhbGciOiJSUzI1NiJ9.eyJuYmYiOnt9LCJleHAiOjQ2ODQyMjUwODd9.guoQvujdWvd3xw7FYQEn4D6-gzM_WqFvXdmvAUNSLbxG7fv2_LLCNujPdrBHJoYPbOwS1BGNxIKQWS1tylvqzmr1RohQ-RZ2iAM1HYQzboUlkoMkcd8ENM__ELqho8aNYBfqwkNdUOyBFoy7Syu_w2SoJADw2RTjnesKO6CVVa05bW118pDS4xWxqC4s7fnBjmZoTn4uQ-Kt9YSQZQk8YQxkJSiyanozzgyfgXULA6mPu1pTNU3FVFaK1i1av_xtH_zAPgb647ZeaNe4nahgqC5h8nhOlm8W2dndXbwAt29nd2ZWBsru_QwZz83XSKLhTPFz-mPBByZZDsyBbIHf9A";
6362
private static final String UNSIGNED_JWT = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJleHAiOi0yMDMzMjI0OTcsImp0aSI6IjEyMyIsInR5cCI6IkpXVCJ9.";
6463

64+
private NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM);
65+
6566
@Test
6667
public void constructorWhenJwkSetUrlIsNullThenThrowIllegalArgumentException() {
6768
assertThatThrownBy(() -> new NimbusJwtDecoderJwkSupport(null))
@@ -80,10 +81,15 @@ public void constructorWhenJwsAlgorithmIsNullThenThrowIllegalArgumentException()
8081
.isInstanceOf(IllegalArgumentException.class);
8182
}
8283

84+
@Test
85+
public void setRestOperationsWhenNullThenThrowIllegalArgumentException() {
86+
Assertions.assertThatThrownBy(() -> this.jwtDecoder.setRestOperations(null))
87+
.isInstanceOf(IllegalArgumentException.class);
88+
}
89+
8390
@Test
8491
public void decodeWhenJwtInvalidThenThrowJwtException() {
85-
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM);
86-
assertThatThrownBy(() -> jwtDecoder.decode("invalid"))
92+
assertThatThrownBy(() -> this.jwtDecoder.decode("invalid"))
8793
.isInstanceOf(JwtException.class);
8894
}
8995

@@ -103,16 +109,14 @@ public void decodeWhenExpClaimNullThenDoesNotThrowException() throws Exception {
103109
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().audience("resource1").build();
104110
when(jwtProcessor.process(any(JWT.class), eq(null))).thenReturn(jwtClaimsSet);
105111

106-
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM);
112+
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL);
107113
assertThatCode(() -> jwtDecoder.decode("encoded-jwt")).doesNotThrowAnyException();
108114
}
109115

110116
// gh-5457
111117
@Test
112-
public void decodeWhenPlainJwtThenExceptionDoesNotMentionClass() throws Exception {
113-
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM);
114-
115-
assertThatCode(() -> jwtDecoder.decode(UNSIGNED_JWT))
118+
public void decodeWhenPlainJwtThenExceptionDoesNotMentionClass() {
119+
assertThatCode(() -> this.jwtDecoder.decode(UNSIGNED_JWT))
116120
.isInstanceOf(JwtException.class)
117121
.hasMessageContaining("Unsupported algorithm of none");
118122
}
@@ -122,12 +126,11 @@ public void decodeWhenJwtIsMalformedThenReturnsStockException() throws Exception
122126
try ( MockWebServer server = new MockWebServer() ) {
123127
server.enqueue(new MockResponse().setBody(JWK_SET));
124128
String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
125-
126-
NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
127-
128-
assertThatCode(() -> decoder.decode(MALFORMED_JWT))
129+
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
130+
assertThatCode(() -> jwtDecoder.decode(MALFORMED_JWT))
129131
.isInstanceOf(JwtException.class)
130132
.hasMessage("An error occurred while attempting to decode the Jwt: Malformed payload");
133+
server.shutdown();
131134
}
132135
}
133136

@@ -136,28 +139,39 @@ public void decodeWhenJwkResponseIsMalformedThenReturnsStockException() throws E
136139
try ( MockWebServer server = new MockWebServer() ) {
137140
server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET));
138141
String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
139-
140-
NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
141-
142-
assertThatCode(() -> decoder.decode(SIGNED_JWT))
142+
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
143+
assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
143144
.isInstanceOf(JwtException.class)
144145
.hasMessage("An error occurred while attempting to decode the Jwt: Malformed Jwk set");
146+
server.shutdown();
145147
}
146148
}
147149

148150
@Test
149-
public void decodeWhenJwkEndpointIsUnresponsiveThenRetrunsJwtException() throws Exception {
151+
public void decodeWhenJwkEndpointIsUnresponsiveThenReturnsJwtException() throws Exception {
150152
try ( MockWebServer server = new MockWebServer() ) {
151153
server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET));
152154
String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
153-
154-
NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
155-
156-
server.shutdown();
157-
158-
assertThatCode(() -> decoder.decode(SIGNED_JWT))
155+
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
156+
assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
159157
.isInstanceOf(JwtException.class)
160158
.hasMessageContaining("An error occurred while attempting to decode the Jwt");
159+
server.shutdown();
160+
}
161+
}
162+
163+
// gh-5603
164+
@Test
165+
public void decodeWhenCustomRestOperationsSetThenUsed() throws Exception {
166+
try ( MockWebServer server = new MockWebServer() ) {
167+
server.enqueue(new MockResponse().setBody(JWK_SET));
168+
String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
169+
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
170+
RestTemplate restTemplate = spy(new RestTemplate());
171+
jwtDecoder.setRestOperations(restTemplate);
172+
assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)).doesNotThrowAnyException();
173+
verify(restTemplate).exchange(any(RequestEntity.class), eq(String.class));
174+
server.shutdown();
161175
}
162176
}
163177
}

0 commit comments

Comments
 (0)