Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions source/extensions/common/aws/credential_provider_chains.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,27 @@ CommonCredentialsProviderChain::customCredentialsProviderChain(
"Custom credential provider chain must have at least one credential provider");
}

return std::make_shared<CommonCredentialsProviderChain>(context, region,
credential_provider_config);
auto chain =
std::make_shared<CommonCredentialsProviderChain>(context, region, credential_provider_config);
chain->setupSubscriptions();
return chain;
}
CredentialsProviderChainSharedPtr CommonCredentialsProviderChain::defaultCredentialsProviderChain(
Server::Configuration::ServerFactoryContext& context, absl::string_view region) {
return std::make_shared<CommonCredentialsProviderChain>(context, region, absl::nullopt);
auto chain = std::make_shared<CommonCredentialsProviderChain>(context, region, absl::nullopt);
chain->setupSubscriptions();
return chain;
}

void CommonCredentialsProviderChain::setupSubscriptions() {
for (auto& provider : providers_) {
// Set up subscription for each provider that supports it
auto metadata_provider = std::dynamic_pointer_cast<MetadataCredentialsProviderBase>(provider);
if (metadata_provider) {
storeSubscription(metadata_provider->subscribeToCredentialUpdates(
std::static_pointer_cast<CredentialSubscriberCallbacks>(shared_from_this())));
}
}
}

CommonCredentialsProviderChain::CommonCredentialsProviderChain(
Expand Down Expand Up @@ -267,8 +282,7 @@ CredentialsProviderSharedPtr CommonCredentialsProviderChain::createAssumeRoleCre
credential_provider->setClusterReadyCallbackHandle(std::move(handleOr.value()));
}

storeSubscription(credential_provider->subscribeToCredentialUpdates(*this));

// Note: Subscription will be set up after construction
return credential_provider;
};

Expand Down Expand Up @@ -303,8 +317,7 @@ CredentialsProviderSharedPtr CommonCredentialsProviderChain::createContainerCred
credential_provider->setClusterReadyCallbackHandle(std::move(handleOr.value()));
}

storeSubscription(credential_provider->subscribeToCredentialUpdates(*this));

// Note: Subscription will be set up after construction
return credential_provider;
}

Expand Down Expand Up @@ -337,8 +350,7 @@ CommonCredentialsProviderChain::createInstanceProfileCredentialsProvider(
credential_provider->setClusterReadyCallbackHandle(std::move(handleOr.value()));
}

storeSubscription(credential_provider->subscribeToCredentialUpdates(*this));

// Note: Subscription will be set up after construction
return credential_provider;
}

Expand Down Expand Up @@ -369,8 +381,7 @@ CredentialsProviderSharedPtr CommonCredentialsProviderChain::createWebIdentityCr
credential_provider->setClusterReadyCallbackHandle(std::move(handleOr.value()));
}

storeSubscription(credential_provider->subscribeToCredentialUpdates(*this));

// Note: Subscription will be set up after construction
return credential_provider;
};

Expand Down
4 changes: 4 additions & 0 deletions source/extensions/common/aws/credential_provider_chains.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ class CommonCredentialsProviderChain : public CredentialsProviderChain,
defaultCredentialsProviderChain(Server::Configuration::ServerFactoryContext& context,
absl::string_view region);

// Where credential providers use async functionality, subscribe to credential notifications for
// these providers
void setupSubscriptions();

private:
CredentialsProviderSharedPtr createEnvironmentCredentialsProvider() const override {
return std::make_shared<EnvironmentCredentialsProvider>();
Expand Down
25 changes: 17 additions & 8 deletions source/extensions/common/aws/credentials_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,28 @@ class CredentialSubscriberCallbacks {
virtual void onCredentialUpdate() PURE;
};

using CredentialSubscriberCallbacksSharedPtr = std::shared_ptr<CredentialSubscriberCallbacks>;

// Subscription model allowing CredentialsProviderChains to be notified of credential provider
// updates. A credential provider chain will call credential_provider->subscribeToCredentialUpdates
// to register itself for updates via onCredentialUpdate callback. When a credential provider has
// successfully updated all threads with new credentials, via the setCredentialsToAllThreads method
// it will notify all subscribers that credentials have been retrieved.
//
// Subscription is only relevant for metadata credentials providers, as these are the only
// credential providers that implement async credential retrieval functionality.
//
// RAII is used, as credential providers may be instantiated as singletons, as such they may outlive
// the credential provider chain. Subscription is only relevant for metadata credentials providers,
// as these are the only credential providers that implement async credential retrieval
// functionality.
class CredentialSubscriberCallbacksHandle : public RaiiListElement<CredentialSubscriberCallbacks*> {
// the credential provider chain.
//
// Uses weak_ptr to safely handle subscriber lifetime without dangling pointers.
class CredentialSubscriberCallbacksHandle
: public RaiiListElement<std::weak_ptr<CredentialSubscriberCallbacks>> {
public:
CredentialSubscriberCallbacksHandle(CredentialSubscriberCallbacks& cb,
std::list<CredentialSubscriberCallbacks*>& parent)
: RaiiListElement<CredentialSubscriberCallbacks*>(parent, &cb) {}
CredentialSubscriberCallbacksHandle(
CredentialSubscriberCallbacksSharedPtr cb,
std::list<std::weak_ptr<CredentialSubscriberCallbacks>>& parent)
: RaiiListElement<std::weak_ptr<CredentialSubscriberCallbacks>>(parent, cb) {}
};

using CredentialSubscriberCallbacksHandlePtr = std::unique_ptr<CredentialSubscriberCallbacksHandle>;
Expand All @@ -187,7 +195,8 @@ using CredentialSubscriberCallbacksHandlePtr = std::unique_ptr<CredentialSubscri
* AWS credentials provider chain, able to fallback between multiple credential providers.
*/
class CredentialsProviderChain : public CredentialSubscriberCallbacks,
public Logger::Loggable<Logger::Id::aws> {
public Logger::Loggable<Logger::Id::aws>,
public std::enable_shared_from_this<CredentialsProviderChain> {
public:
~CredentialsProviderChain() override {
for (auto& subscriber_handle : subscriber_handles_) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,17 @@ MetadataCredentialsProviderBase::MetadataCredentialsProviderBase(
void MetadataCredentialsProviderBase::onClusterAddOrUpdate() {
ENVOY_LOG(debug, "Received callback from aws cluster manager for cluster {}", cluster_name_);
if (!cache_duration_timer_) {
cache_duration_timer_ = context_.mainThreadDispatcher().createTimer([this]() -> void {
stats_->credential_refreshes_performed_.inc();
refresh();
});
std::weak_ptr<MetadataCredentialsProviderStats> weak_stats = stats_;
std::weak_ptr<MetadataCredentialsProviderBase> weak_self = shared_from_this();
cache_duration_timer_ =
context_.mainThreadDispatcher().createTimer([weak_stats, weak_self]() -> void {
if (auto stats = weak_stats.lock()) {
stats->credential_refreshes_performed_.inc();
}
if (auto self = weak_self.lock()) {
self->refresh();
}
});
}
if (!cache_duration_timer_->enabled()) {
cache_duration_timer_->enableTimer(std::chrono::milliseconds(1));
Expand Down Expand Up @@ -116,21 +123,24 @@ void MetadataCredentialsProviderBase::setCredentialsToAllThreads(
/* Notify waiting signers on completion of credential setting above */
[this]() {
credentials_pending_.store(false);
std::list<CredentialSubscriberCallbacks*> subscribers_copy;
std::list<std::weak_ptr<CredentialSubscriberCallbacks>> subscribers_copy;
{
Thread::LockGuard guard(mu_);
subscribers_copy = credentials_subscribers_;
}
for (auto& cb : subscribers_copy) {
ENVOY_LOG(debug, "Notifying subscriber of credential update");
cb->onCredentialUpdate();
for (auto& weak_cb : subscribers_copy) {
if (auto cb = weak_cb.lock()) {
ENVOY_LOG(debug, "Notifying subscriber of credential update");
cb->onCredentialUpdate();
}
}
});
}
}

CredentialSubscriberCallbacksHandlePtr
MetadataCredentialsProviderBase::subscribeToCredentialUpdates(CredentialSubscriberCallbacks& cs) {
MetadataCredentialsProviderBase::subscribeToCredentialUpdates(
CredentialSubscriberCallbacksSharedPtr cs) {
Thread::LockGuard guard(mu_);
return std::make_unique<CredentialSubscriberCallbacksHandle>(cs, credentials_subscribers_);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ struct MetadataCredentialsProviderStats {
using CreateMetadataFetcherCb =
std::function<MetadataFetcherPtr(Upstream::ClusterManager&, absl::string_view)>;

class MetadataCredentialsProviderBase : public CredentialsProvider,
public Logger::Loggable<Logger::Id::aws>,
public AwsManagedClusterUpdateCallbacks {
class MetadataCredentialsProviderBase
: public CredentialsProvider,
public Logger::Loggable<Logger::Id::aws>,
public AwsManagedClusterUpdateCallbacks,
public std::enable_shared_from_this<MetadataCredentialsProviderBase> {
public:
friend class MetadataCredentialsProviderBaseFriend;
using OnAsyncFetchCb = std::function<void(const std::string&&)>;
Expand All @@ -53,7 +55,7 @@ class MetadataCredentialsProviderBase : public CredentialsProvider,
}

CredentialSubscriberCallbacksHandlePtr
subscribeToCredentialUpdates(CredentialSubscriberCallbacks& cs);
subscribeToCredentialUpdates(CredentialSubscriberCallbacksSharedPtr cs);

protected:
struct ThreadLocalCredentialsCache : public ThreadLocal::ThreadLocalObject {
Expand Down Expand Up @@ -120,7 +122,8 @@ class MetadataCredentialsProviderBase : public CredentialsProvider,
// Are credentials pending?
std::atomic<bool> credentials_pending_ = true;
Thread::MutexBasicLockable mu_;
std::list<CredentialSubscriberCallbacks*> credentials_subscribers_ ABSL_GUARDED_BY(mu_);
std::list<std::weak_ptr<CredentialSubscriberCallbacks>>
credentials_subscribers_ ABSL_GUARDED_BY(mu_);
};

} // namespace Aws
Expand Down
138 changes: 134 additions & 4 deletions test/extensions/common/aws/credentials_provider_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "source/extensions/common/aws/signers/sigv4_signer_impl.h"

#include "test/extensions/common/aws/mocks.h"
#include "test/mocks/event/mocks.h"
#include "test/mocks/server/server_factory_context.h"

#include "gtest/gtest.h"
Expand Down Expand Up @@ -81,7 +82,7 @@ class AsyncCredentialHandlingTest : public testing::Test {
MetadataFetcherPtr metadata_fetcher_;
NiceMock<Server::Configuration::MockServerFactoryContext> context_;
WebIdentityCredentialsProviderPtr provider_;
Event::MockTimer* timer_{};
Event::MockTimer* timer_;
NiceMock<Upstream::MockClusterManager> cm_;
std::shared_ptr<MockAwsClusterManager> mock_manager_;
Http::RequestMessagePtr message_;
Expand Down Expand Up @@ -180,7 +181,7 @@ TEST_F(AsyncCredentialHandlingTest, ChainCallbackCalledWhenCredentialsReturned)
}
)EOF";

auto handle = provider_->subscribeToCredentialUpdates(*chain);
auto handle = provider_->subscribeToCredentialUpdates(chain);

auto signer = std::make_unique<Extensions::Common::Aws::SigV4SignerImpl>(
"vpc-lattice-svcs", "ap-southeast-2", chain, context_,
Expand Down Expand Up @@ -249,8 +250,8 @@ TEST_F(AsyncCredentialHandlingTest, SubscriptionsCleanedUp) {
}
)EOF";

auto handle = provider_->subscribeToCredentialUpdates(*chain);
auto handle2 = provider_->subscribeToCredentialUpdates(*chain);
auto handle = provider_->subscribeToCredentialUpdates(chain);
auto handle2 = provider_->subscribeToCredentialUpdates(chain);

auto signer = std::make_unique<Extensions::Common::Aws::SigV4SignerImpl>(
"vpc-lattice-svcs", "ap-southeast-2", chain, context_,
Expand All @@ -275,6 +276,135 @@ TEST_F(AsyncCredentialHandlingTest, SubscriptionsCleanedUp) {
ASSERT_TRUE(result.ok());
}

// Mock WebIdentityCredentialsProvider to track refresh calls
class MockWebIdentityProvider : public WebIdentityCredentialsProvider {
public:
MockWebIdentityProvider(
Server::Configuration::ServerFactoryContext& context,
AwsClusterManagerPtr aws_cluster_manager, absl::string_view cluster_name,
CreateMetadataFetcherCb create_metadata_fetcher_cb,
MetadataFetcher::MetadataReceiver::RefreshState refresh_state,
std::chrono::seconds initialization_timer,
const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& config)
: WebIdentityCredentialsProvider(context, aws_cluster_manager, cluster_name,
create_metadata_fetcher_cb, refresh_state,
initialization_timer, config) {}
MOCK_METHOD(void, refresh, (), (override));
};

TEST_F(AsyncCredentialHandlingTest, WeakPtrProtectionInTimerCallback) {

MetadataFetcher::MetadataReceiver::RefreshState refresh_state =
MetadataFetcher::MetadataReceiver::RefreshState::Ready;
std::chrono::seconds initialization_timer = std::chrono::seconds(2);

envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider cred_provider =
{};
cred_provider.mutable_web_identity_token_data_source()->set_inline_string("token");
cred_provider.set_role_arn("aws:iam::123456789012:role/arn");
cred_provider.set_role_session_name("session");

mock_manager_ = std::make_shared<MockAwsClusterManager>();
EXPECT_CALL(*mock_manager_, getUriFromClusterName(_)).WillRepeatedly(Return("uri"));

auto mock_provider = std::make_shared<MockWebIdentityProvider>(
context_, mock_manager_, "cluster",
[this](Upstream::ClusterManager&, absl::string_view) {
metadata_fetcher_.reset(raw_metadata_fetcher_);
return std::move(metadata_fetcher_);
},
refresh_state, initialization_timer, cred_provider);

timer_ = new NiceMock<Event::MockTimer>(&context_.dispatcher_);
Event::MockTimer* timer_ptr = timer_; // Keep raw pointer to test after provider destruction
auto provider_friend = MetadataCredentialsProviderBaseFriend(mock_provider);

// When provider is alive, refresh should be called
EXPECT_CALL(*mock_provider, refresh());
provider_friend.onClusterAddOrUpdate();
timer_ptr->enabled_ = true;
timer_ptr->invokeCallback();
delete (raw_metadata_fetcher_);
}

TEST_F(AsyncCredentialHandlingTest, WeakPtrProtectionForStatsInTimerCallback) {
MetadataFetcher::MetadataReceiver::RefreshState refresh_state =
MetadataFetcher::MetadataReceiver::RefreshState::Ready;
std::chrono::seconds initialization_timer = std::chrono::seconds(2);

envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider cred_provider =
{};
cred_provider.mutable_web_identity_token_data_source()->set_inline_string("token");
cred_provider.set_role_arn("aws:iam::123456789012:role/arn");
cred_provider.set_role_session_name("session");

mock_manager_ = std::make_shared<MockAwsClusterManager>();
EXPECT_CALL(*mock_manager_, getUriFromClusterName(_)).WillRepeatedly(Return("uri"));

auto mock_provider = std::make_shared<MockWebIdentityProvider>(
context_, mock_manager_, "cluster",
[this](Upstream::ClusterManager&, absl::string_view) {
metadata_fetcher_.reset(raw_metadata_fetcher_);
return std::move(metadata_fetcher_);
},
refresh_state, initialization_timer, cred_provider);

timer_ = new NiceMock<Event::MockTimer>(&context_.dispatcher_);
Event::MockTimer* timer_ptr = timer_;
auto provider_friend = MetadataCredentialsProviderBaseFriend(mock_provider);
provider_friend.onClusterAddOrUpdate();

// Invalidate stats pointer
provider_friend.invalidateStats();

// Timer callback will skip the stats call due to weak_ptr lock failing
EXPECT_CALL(*mock_provider, refresh());
timer_ptr->enabled_ = true;
timer_ptr->invokeCallback();
delete (raw_metadata_fetcher_);
}

TEST_F(AsyncCredentialHandlingTest, WeakPtrProtectionInSubscriberCallback) {
MetadataFetcher::MetadataReceiver::RefreshState refresh_state =
MetadataFetcher::MetadataReceiver::RefreshState::Ready;
std::chrono::seconds initialization_timer = std::chrono::seconds(2);

envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider cred_provider =
{};
cred_provider.mutable_web_identity_token_data_source()->set_inline_string("token");
cred_provider.set_role_arn("aws:iam::123456789012:role/arn");
cred_provider.set_role_session_name("session");

mock_manager_ = std::make_shared<MockAwsClusterManager>();
EXPECT_CALL(*mock_manager_, getUriFromClusterName(_)).WillRepeatedly(Return("uri"));

provider_ = std::make_shared<WebIdentityCredentialsProvider>(
context_, mock_manager_, "cluster",
[this](Upstream::ClusterManager&, absl::string_view) {
metadata_fetcher_.reset(raw_metadata_fetcher_);
return std::move(metadata_fetcher_);
},
refresh_state, initialization_timer, cred_provider);

auto provider_friend = MetadataCredentialsProviderBaseFriend(provider_);

// Test 1: When subscriber is alive, onCredentialUpdate should be called
auto chain = std::make_shared<MockCredentialsProviderChain>();
EXPECT_CALL(*chain, onCredentialUpdate());
auto handle = provider_->subscribeToCredentialUpdates(chain);

// Trigger credential update
provider_friend.setCredentialsToAllThreads(std::make_unique<Credentials>("key", "secret"));

// Test 2: When subscriber is destroyed, onCredentialUpdate should not be called
EXPECT_CALL(*chain, onCredentialUpdate()).Times(0);
chain.reset(); // Destroy the subscriber

// Trigger credential update - should not crash due to weak_ptr protection
provider_friend.setCredentialsToAllThreads(std::make_unique<Credentials>("key2", "secret2"));
delete (raw_metadata_fetcher_);
}

class ControlledCredentialsProvider : public CredentialsProvider {
public:
ControlledCredentialsProvider(CredentialSubscriberCallbacks* cb) : cb_(cb) {}
Expand Down
5 changes: 5 additions & 0 deletions test/extensions/common/aws/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ class MetadataCredentialsProviderBaseFriend {
provider_->metadata_fetcher_ = std::move(fetcher);
}
void setCacheDurationTimer(Event::Timer* timer) { provider_->cache_duration_timer_.reset(timer); }
void setCredentialsToAllThreads(CredentialsConstUniquePtr&& creds) {
provider_->setCredentialsToAllThreads(std::move(creds));
}
void invalidateStats() { provider_->stats_.reset(); }

std::shared_ptr<MetadataCredentialsProviderBase> provider_;
};

Expand Down