Skip to content

Provide possibility to use custom cache to store JWK Set #8332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.RemoteKeySourceException;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.JWSKeySelector;
Expand All @@ -49,6 +51,7 @@
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTProcessor;

import org.springframework.cache.Cache;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
Expand All @@ -68,6 +71,7 @@
*
* @author Josh Cummings
* @author Joe Grandja
* @author Mykyta Bezverkhyi
* @since 5.2
*/
public final class NimbusJwtDecoder implements JwtDecoder {
Expand Down Expand Up @@ -215,6 +219,7 @@ public static final class JwkSetUriJwtDecoderBuilder {
private String jwkSetUri;
private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
private RestOperations restOperations = new RestTemplate();
private Cache cache;

private JwkSetUriJwtDecoderBuilder(String jwkSetUri) {
Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty");
Expand Down Expand Up @@ -264,6 +269,20 @@ public JwkSetUriJwtDecoderBuilder restOperations(RestOperations restOperations)
return this;
}

/**
* Use the given {@link Cache} to store
* <a href="https://tools.ietf.org/html/rfc7517#section-5">JWK Set</a>.
*
* @param cache the {@link Cache} to be used to store JWK Set
* @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations
* @since 5.4
*/
public JwkSetUriJwtDecoderBuilder cache(Cache cache) {
Assert.notNull(cache, "cache cannot be null");
this.cache = cache;
return this;
}

JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSource) {
if (this.signatureAlgorithms.isEmpty()) {
return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
Expand All @@ -280,9 +299,17 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
}
}

JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever) {
if (this.cache == null) {
return new RemoteJWKSet<>(toURL(this.jwkSetUri), jwkSetRetriever);
}
ResourceRetriever cachingJwkSetRetriever = new CachingResourceRetriever(this.cache, jwkSetRetriever);
return new RemoteJWKSet<>(toURL(this.jwkSetUri), cachingJwkSetRetriever, new NoOpJwkSetCache());
}

JWTProcessor<SecurityContext> processor() {
ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations);
JWKSource<SecurityContext> jwkSource = new RemoteJWKSet<>(toURL(this.jwkSetUri), jwkSetRetriever);
JWKSource<SecurityContext> jwkSource = jwkSource(jwkSetRetriever);
ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource));

Expand All @@ -309,6 +336,44 @@ private static URL toURL(String url) {
}
}

private static class NoOpJwkSetCache implements JWKSetCache {
@Override
public void put(JWKSet jwkSet) {
}

@Override
public JWKSet get() {
return null;
}

@Override
public boolean requiresRefresh() {
return true;
}
}

private static class CachingResourceRetriever implements ResourceRetriever {
private final Cache cache;
private final ResourceRetriever resourceRetriever;

CachingResourceRetriever(Cache cache, ResourceRetriever resourceRetriever) {
this.cache = cache;
this.resourceRetriever = resourceRetriever;
}

@Override
public Resource retrieveResource(URL url) throws IOException {
String jwkSet;
try {
jwkSet = cache.get(url.toString(), () -> resourceRetriever.retrieveResource(url).getContent());
} catch (Exception ex) {
throw new IOException(ex);
}

return new Resource(jwkSet, "UTF-8");
}
}

private static class RestOperationsResourceRetriever implements ResourceRetriever {
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
private final RestOperations restOperations;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import javax.crypto.SecretKey;

import com.nimbusds.jose.JWSAlgorithm;
Expand All @@ -55,6 +56,8 @@
import org.junit.Test;

import org.mockito.ArgumentCaptor;
import org.springframework.cache.Cache;
import org.springframework.cache.concurrent.ConcurrentMapCache;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
Expand All @@ -66,6 +69,7 @@
import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestOperations;

import static org.assertj.core.api.Assertions.assertThat;
Expand All @@ -75,6 +79,8 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withPublicKey;
Expand All @@ -85,6 +91,7 @@
*
* @author Josh Cummings
* @author Joe Grandja
* @author Mykyta Bezverkhyi
*/
public class NimbusJwtDecoderTests {
private static final String JWK_SET = "{\"keys\":[{\"p\":\"49neceJFs8R6n7WamRGy45F5Tv0YM-R2ODK3eSBUSLOSH2tAqjEVKOkLE5fiNA3ygqq15NcKRadB2pTVf-Yb5ZIBuKzko8bzYIkIqYhSh_FAdEEr0vHF5fq_yWSvc6swsOJGqvBEtuqtJY027u-G2gAQasCQdhyejer68zsTn8M\",\"kty\":\"RSA\",\"q\":\"tWR-ysspjZ73B6p2vVRVyHwP3KQWL5KEQcdgcmMOE_P_cPs98vZJfLhxobXVmvzuEWBpRSiqiuyKlQnpstKt94Cy77iO8m8ISfF3C9VyLWXi9HUGAJb99irWABFl3sNDff5K2ODQ8CmuXLYM25OwN3ikbrhEJozlXg_NJFSGD4E\",\"d\":\"FkZHYZlw5KSoqQ1i2RA2kCUygSUOf1OqMt3uomtXuUmqKBm_bY7PCOhmwbvbn4xZYEeHuTR8Xix-0KpHe3NKyWrtRjkq1T_un49_1LLVUhJ0dL-9_x0xRquVjhl_XrsRXaGMEHs8G9pLTvXQ1uST585gxIfmCe0sxPZLvwoic-bXf64UZ9BGRV3lFexWJQqCZp2S21HfoU7wiz6kfLRNi-K4xiVNB1gswm_8o5lRuY7zB9bRARQ3TS2G4eW7p5sxT3CgsGiQD3_wPugU8iDplqAjgJ5ofNJXZezoj0t6JMB_qOpbrmAM1EnomIPebSLW7Ky9SugEd6KMdL5lW6AuAQ\",\"e\":\"AQAB\",\"use\":\"sig\",\"kid\":\"one\",\"qi\":\"wdkFu_tV2V1l_PWUUimG516Zvhqk2SWDw1F7uNDD-Lvrv_WNRIJVzuffZ8WYiPy8VvYQPJUrT2EXL8P0ocqwlaSTuXctrORcbjwgxDQDLsiZE0C23HYzgi0cofbScsJdhcBg7d07LAf7cdJWG0YVl1FkMCsxUlZ2wTwHfKWf-v4\",\"dp\":\"uwnPxqC-IxG4r33-SIT02kZC1IqC4aY7PWq0nePiDEQMQWpjjNH50rlq9EyLzbtdRdIouo-jyQXB01K15-XXJJ60dwrGLYNVqfsTd0eGqD1scYJGHUWG9IDgCsxyEnuG3s0AwbW2UolWVSsU2xMZGb9PurIUZECeD1XDZwMp2s0\",\"dq\":\"hra786AunB8TF35h8PpROzPoE9VJJMuLrc6Esm8eZXMwopf0yhxfN2FEAvUoTpLJu93-UH6DKenCgi16gnQ0_zt1qNNIVoRfg4rw_rjmsxCYHTVL3-RDeC8X_7TsEySxW0EgFTHh-nr6I6CQrAJjPM88T35KHtdFATZ7BCBB8AE\",\"n\":\"oXJ8OyOv_eRnce4akdanR4KYRfnC2zLV4uYNQpcFn6oHL0dj7D6kxQmsXoYgJV8ZVDn71KGmuLvolxsDncc2UrhyMBY6DVQVgMSVYaPCTgW76iYEKGgzTEw5IBRQL9w3SRJWd3VJTZZQjkXef48Ocz06PGF3lhbz4t5UEZtdF4rIe7u-977QwHuh7yRPBQ3sII-cVoOUMgaXB9SHcGF2iZCtPzL_IffDUcfhLQteGebhW8A6eUHgpD5A1PQ-JCw_G7UOzZAjjDjtNM2eqm8j-Ms_gqnm4MiCZ4E-9pDN77CAAPVN7kuX6ejs9KBXpk01z48i9fORYk9u7rAkh1HuQw\"}]}";
Expand Down Expand Up @@ -247,6 +254,21 @@ public void decodeWhenJwkEndpointIsUnresponsiveThenReturnsJwtException() throws
}
}

@Test
public void shouldThrowJwtExceptionWhenJwkSetEndpointHasNotRespondedAndCacheIsConfigured() throws Exception {
try ( MockWebServer server = new MockWebServer() ) {
Cache cache = new ConcurrentMapCache("test-jwk-set-cache");
String jwkSetUri = server.url("/.well-known/jwks.json").toString();
NimbusJwtDecoder jwtDecoder = withJwkSetUri(jwkSetUri).cache(cache).build();

server.shutdown();
assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
.isInstanceOf(JwtException.class)
.isNotInstanceOf(BadJwtException.class)
.hasMessageContaining("An error occurred while attempting to decode the Jwt");
}
}

@Test
public void withJwkSetUriWhenNullOrEmptyThenThrowsException() {
Assertions.assertThatCode(() -> withJwkSetUri(null)).isInstanceOf(IllegalArgumentException.class);
Expand All @@ -264,6 +286,12 @@ public void restOperationsWhenNullThenThrowsException() {
Assertions.assertThatCode(() -> builder.restOperations(null)).isInstanceOf(IllegalArgumentException.class);
}

@Test
public void shouldThrowIllegalArgumentExceptionWhenJwkSetCacheIsNull() {
NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = withJwkSetUri(JWK_SET_URI);
Assertions.assertThatCode(() -> builder.cache(null)).isInstanceOf(IllegalArgumentException.class);
}

@Test
public void withPublicKeyWhenNullThenThrowsException() {
assertThatThrownBy(() -> withPublicKey(null))
Expand Down Expand Up @@ -425,7 +453,7 @@ public void decodeWhenJwkSetRequestedThenAcceptHeaderJsonAndJwkSetJson() {
RestOperations restOperations = mock(RestOperations.class);
when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
.thenReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
JWTProcessor<SecurityContext> processor = withJwkSetUri("https://issuer/.well-known/jwks.json")
JWTProcessor<SecurityContext> processor = withJwkSetUri(JWK_SET_URI)
.restOperations(restOperations)
.processor();
NimbusJwtDecoder jwtDecoder = new NimbusJwtDecoder(processor);
Expand All @@ -436,6 +464,64 @@ public void decodeWhenJwkSetRequestedThenAcceptHeaderJsonAndJwkSetJson() {
assertThat(acceptHeader).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON);
}

@Test
public void shouldStoreRetrievedJwkSetToCache() {
// given
Cache cache = new ConcurrentMapCache("test-jwk-set-cache");
RestOperations restOperations = mock(RestOperations.class);
when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
.thenReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI)
.restOperations(restOperations)
.cache(cache)
.build();
// when
jwtDecoder.decode(SIGNED_JWT);
// then
assertThat(cache.get(JWK_SET_URI, String.class)).isEqualTo(JWK_SET);
ArgumentCaptor<RequestEntity> requestEntityCaptor = ArgumentCaptor.forClass(RequestEntity.class);
verify(restOperations).exchange(requestEntityCaptor.capture(), eq(String.class));
verifyNoMoreInteractions(restOperations);
List<MediaType> acceptHeader = requestEntityCaptor.getValue().getHeaders().getAccept();
assertThat(acceptHeader).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON);
}

@Test
public void shouldDecodeJwtUsingJwkSetCache() {
// given
RestOperations restOperations = mock(RestOperations.class);
Cache cache = mock(Cache.class);
when(cache.get(eq(JWK_SET_URI), any(Callable.class))).thenReturn(JWK_SET);
NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI)
.cache(cache)
.restOperations(restOperations)
.build();
// when
jwtDecoder.decode(SIGNED_JWT);
// then
verify(cache).get(eq(JWK_SET_URI), any(Callable.class));
verifyNoMoreInteractions(cache);
verifyNoInteractions(restOperations);
}

@Test
public void shouldThrowJwtExceptionWhenExceptionOccurredWhileRetrievingJwkSetInsideCachingRetriever() {
// given
Cache cache = new ConcurrentMapCache("test-jwk-set-cache");
RestOperations restOperations = mock(RestOperations.class);
when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
.thenThrow(new RestClientException("Cannot retrieve JWK Set"));
NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI)
.restOperations(restOperations)
.cache(cache)
.build();
// then
assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
.isInstanceOf(JwtException.class)
.isNotInstanceOf(BadJwtException.class)
.hasMessageContaining("An error occurred while attempting to decode the Jwt");
}

private RSAPublicKey key() throws InvalidKeySpecException {
byte[] decoded = Base64.getDecoder().decode(VERIFY_KEY.getBytes());
EncodedKeySpec spec = new X509EncodedKeySpec(decoded);
Expand Down Expand Up @@ -466,7 +552,7 @@ private static JWTProcessor<SecurityContext> withSigning(String jwkResponse) {
RestOperations restOperations = mock(RestOperations.class);
when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
.thenReturn(new ResponseEntity<>(jwkResponse, HttpStatus.OK));
return withJwkSetUri("https://issuer/.well-known/jwks.json")
return withJwkSetUri(JWK_SET_URI)
.restOperations(restOperations)
.processor();
}
Expand Down