Skip to content

Improve cache invalidation in IdP SP cache #128890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/128890.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 128890
summary: Improve cache invalidation in IdP SP cache
area: IdentityProvider
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;

public class IdentityProviderAuthenticationIT extends IdpRestTestCase {
Expand Down Expand Up @@ -89,6 +90,52 @@ public void testRegistrationAndIdpInitiatedSso() throws Exception {
authenticateWithSamlResponse(samlResponse, null);
}

public void testUpdateExistingServiceProvider() throws Exception {
final Map<String, Object> request1 = Map.ofEntries(
Map.entry("name", "Test SP [v1]"),
Map.entry("acs", SP_ACS),
Map.entry("privileges", Map.ofEntries(Map.entry("resource", SP_ENTITY_ID), Map.entry("roles", List.of("sso:(\\w+)")))),
Map.entry(
"attributes",
Map.ofEntries(
Map.entry("principal", "https://idp.test.es.elasticsearch.org/attribute/principal"),
Map.entry("name", "https://idp.test.es.elasticsearch.org/attribute/name"),
Map.entry("email", "https://idp.test.es.elasticsearch.org/attribute/email"),
Map.entry("roles", "https://idp.test.es.elasticsearch.org/attribute/roles")
)
)
);
final SamlServiceProviderIndex.DocumentVersion docVersion1 = createServiceProvider(SP_ENTITY_ID, request1);
checkIndexDoc(docVersion1);
ensureGreen(SamlServiceProviderIndex.INDEX_NAME);

final String samlResponse1 = generateSamlResponse(SP_ENTITY_ID, SP_ACS, null);
assertThat(samlResponse1, containsString("https://idp.test.es.elasticsearch.org/attribute/principal"));
assertThat(samlResponse1, not(containsString("https://idp.test.es.elasticsearch.org/attribute/username")));

final Map<String, Object> request = Map.ofEntries(
Map.entry("name", "Test SP [v2]"),
Map.entry("acs", SP_ACS),
Map.entry("privileges", Map.ofEntries(Map.entry("resource", SP_ENTITY_ID), Map.entry("roles", List.of("sso:(\\w+)")))),
Map.entry(
"attributes",
Map.ofEntries(
Map.entry("principal", "https://idp.test.es.elasticsearch.org/attribute/username"),
Map.entry("name", "https://idp.test.es.elasticsearch.org/attribute/name"),
Map.entry("email", "https://idp.test.es.elasticsearch.org/attribute/email"),
Map.entry("roles", "https://idp.test.es.elasticsearch.org/attribute/roles")
)
)
);
final SamlServiceProviderIndex.DocumentVersion docVersion2 = createServiceProvider(SP_ENTITY_ID, request);
checkIndexDoc(docVersion2);
ensureGreen(SamlServiceProviderIndex.INDEX_NAME);

final String samlResponse2 = generateSamlResponse(SP_ENTITY_ID, SP_ACS, null);
assertThat(samlResponse2, containsString("https://idp.test.es.elasticsearch.org/attribute/username"));
assertThat(samlResponse2, not(containsString("https://idp.test.es.elasticsearch.org/attribute/principal")));
}

public void testCustomAttributesInIdpInitiatedSso() throws Exception {
final Map<String, Object> request = Map.ofEntries(
Map.entry("name", "Test SP With Custom Attributes"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ public Collection<?> createComponents(PluginServices services) {
index,
serviceProviderFactory
);
services.clusterService().addListener(registeredServiceProviderResolver);

final WildcardServiceProviderResolver wildcardServiceProviderResolver = WildcardServiceProviderResolver.create(
services.environment(),
services.resourceWatcherService(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.metadata.IndexAbstraction;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.util.CachedSupplier;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.get.GetResult;
import org.elasticsearch.index.query.QueryBuilder;
Expand All @@ -51,6 +53,7 @@
import java.util.Arrays;
import java.util.Objects;
import java.util.Set;
import java.util.SortedMap;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand Down Expand Up @@ -152,6 +155,21 @@ private void checkForAliasStateChange(ClusterState state) {
}
}

Index getIndex(ClusterState state) {
final ProjectMetadata project = state.getMetadata().getProject();
final SortedMap<String, IndexAbstraction> indicesLookup = project.getIndicesLookup();

IndexAbstraction indexAbstraction = indicesLookup.get(ALIAS_NAME);
if (indexAbstraction == null) {
indexAbstraction = indicesLookup.get(INDEX_NAME);
}
if (indexAbstraction == null) {
return null;
} else {
return indexAbstraction.getWriteIndex();
}
}

@Override
public void close() {
logger.debug("Closing ... removing cluster state listener");
Expand Down Expand Up @@ -255,7 +273,12 @@ public void refresh(ActionListener<Void> listener) {

private void findDocuments(QueryBuilder query, ActionListener<Set<DocumentSupplier>> listener) {
logger.trace("Searching [{}] for [{}]", ALIAS_NAME, query);
final SearchRequest request = client.prepareSearch(ALIAS_NAME).setQuery(query).setSize(1000).setFetchSource(true).request();
final SearchRequest request = client.prepareSearch(ALIAS_NAME)
.setQuery(query)
.setSize(1000)
.setFetchSource(true)
.seqNoAndPrimaryTerm(true)
.request();
client.search(request, ActionListener.wrap(response -> {
if (logger.isTraceEnabled()) {
logger.trace("Search hits: [{}] [{}]", response.getHits().getTotalHits(), Arrays.toString(response.getHits().getHits()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,22 @@

package org.elasticsearch.xpack.idp.saml.sp;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.iterable.Iterables;
import org.elasticsearch.index.Index;
import org.elasticsearch.xpack.idp.saml.sp.SamlServiceProviderIndex.DocumentSupplier;
import org.elasticsearch.xpack.idp.saml.sp.SamlServiceProviderIndex.DocumentVersion;

import java.util.Objects;
import java.util.stream.Collectors;

public class SamlServiceProviderResolver {
public class SamlServiceProviderResolver implements ClusterStateListener {

private final Cache<String, CachedServiceProvider> cache;
private final SamlServiceProviderIndex index;
Expand All @@ -32,6 +38,8 @@ public SamlServiceProviderResolver(
this.serviceProviderFactory = serviceProviderFactory;
}

private final Logger logger = LogManager.getLogger(getClass());

/**
* Find a {@link SamlServiceProvider} by entity-id.
*
Expand Down Expand Up @@ -75,6 +83,16 @@ private void populateCacheAndReturn(String entityId, DocumentSupplier doc, Actio
listener.onResponse(serviceProvider);
}

@Override
public void clusterChanged(ClusterChangedEvent event) {
final Index previousIndex = index.getIndex(event.previousState());
final Index currentIndex = index.getIndex(event.state());
if (Objects.equals(previousIndex, currentIndex) == false) {
logger.info("Index has changed [{}] => [{}], clearing cache", previousIndex, currentIndex);
this.cache.invalidateAll();
}
}

private class CachedServiceProvider {
private final String entityId;
private final DocumentVersion documentVersion;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.Index;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.idp.saml.idp.SamlIdentityProvider;
import org.elasticsearch.xpack.idp.saml.sp.SamlServiceProviderIndex.DocumentVersion;
Expand Down Expand Up @@ -135,6 +139,37 @@ public void testResolveIgnoresCacheWhenDocumentVersionChanges() throws Exception
assertThat(serviceProvider2.getPrivileges().getResource(), equalTo(document2.privileges.resource));
}

public void testCacheIsClearedWhenIndexChanges() throws Exception {
final SamlServiceProviderDocument document1 = SamlServiceProviderTestUtils.randomDocument(1);
final SamlServiceProviderDocument document2 = SamlServiceProviderTestUtils.randomDocument(2);
document2.entityId = document1.entityId;

final DocumentVersion docVersion = new DocumentVersion(randomAlphaOfLength(12), 1, 1);

mockDocument(document1.entityId, docVersion, document1);
final SamlServiceProvider serviceProvider1a = resolveServiceProvider(document1.entityId);
final SamlServiceProvider serviceProvider1b = resolveServiceProvider(document1.entityId);
assertThat(serviceProvider1b, sameInstance(serviceProvider1a));

final ClusterState oldState = ClusterState.builder(ClusterName.DEFAULT).build();
final ClusterState newState = ClusterState.builder(ClusterName.DEFAULT).build();
when(index.getIndex(oldState)).thenReturn(new Index(SamlServiceProviderIndex.INDEX_NAME, randomUUID()));
when(index.getIndex(newState)).thenReturn(new Index(SamlServiceProviderIndex.INDEX_NAME, randomUUID()));
resolver.clusterChanged(new ClusterChangedEvent(getTestName(), newState, oldState));

mockDocument(document1.entityId, docVersion, document2);
final SamlServiceProvider serviceProvider2 = resolveServiceProvider(document1.entityId);

assertThat(serviceProvider2, not(sameInstance(serviceProvider1a)));
assertThat(serviceProvider2.getEntityId(), equalTo(document2.entityId));
assertThat(serviceProvider2.getAssertionConsumerService().toString(), equalTo(document2.acs));
assertThat(serviceProvider2.getAttributeNames().principal, equalTo(document2.attributeNames.principal));
assertThat(serviceProvider2.getAttributeNames().name, equalTo(document2.attributeNames.name));
assertThat(serviceProvider2.getAttributeNames().email, equalTo(document2.attributeNames.email));
assertThat(serviceProvider2.getAttributeNames().roles, equalTo(document2.attributeNames.roles));
assertThat(serviceProvider2.getPrivileges().getResource(), equalTo(document2.privileges.resource));
}

private SamlServiceProvider resolveServiceProvider(String entityId) {
final PlainActionFuture<SamlServiceProvider> future = new PlainActionFuture<>();
resolver.resolve(entityId, future);
Expand Down
Loading