Skip to content

Commit 2bfb303

Browse files
tyrannosaurus-becksLauren Voswinkel
authored andcommitted
Retry on transient failures during AWS IAM auth login attempts (#8727)
* use retryer for failed aws auth attempts * fixes from testing
1 parent 8a3cbf0 commit 2bfb303

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

builtin/credential/aws/path_login.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ import (
1616
"time"
1717

1818
"github.com/aws/aws-sdk-go/aws"
19+
awsClient "github.com/aws/aws-sdk-go/aws/client"
1920
"github.com/aws/aws-sdk-go/service/ec2"
2021
"github.com/aws/aws-sdk-go/service/iam"
2122
"github.com/fullsailor/pkcs7"
2223
"github.com/hashicorp/errwrap"
2324
cleanhttp "github.com/hashicorp/go-cleanhttp"
25+
"github.com/hashicorp/go-retryablehttp"
2426
uuid "github.com/hashicorp/go-uuid"
2527
"github.com/hashicorp/vault/sdk/framework"
2628
"github.com/hashicorp/vault/sdk/helper/awsutil"
@@ -35,6 +37,10 @@ const (
3537
iamAuthType = "iam"
3638
ec2AuthType = "ec2"
3739
ec2EntityType = "ec2_instance"
40+
41+
// Retry configuration
42+
retryWaitMin = 500 * time.Millisecond
43+
retryWaitMax = 30 * time.Second
3844
)
3945

4046
func (b *backend) pathLogin() *framework.Path {
@@ -1199,6 +1205,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
11991205

12001206
endpoint := "https://sts.amazonaws.com"
12011207

1208+
maxRetries := awsClient.DefaultRetryerMaxNumRetries
12021209
if config != nil {
12031210
if config.IAMServerIdHeaderValue != "" {
12041211
err = validateVaultHeaderValue(headers, parsedUrl, config.IAMServerIdHeaderValue)
@@ -1209,9 +1216,12 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
12091216
if config.STSEndpoint != "" {
12101217
endpoint = config.STSEndpoint
12111218
}
1219+
if config.MaxRetries >= 0 {
1220+
maxRetries = config.MaxRetries
1221+
}
12121222
}
12131223

1214-
callerID, err := submitCallerIdentityRequest(method, endpoint, parsedUrl, body, headers)
1224+
callerID, err := submitCallerIdentityRequest(ctx, maxRetries, method, endpoint, parsedUrl, body, headers)
12151225
if err != nil {
12161226
return logical.ErrorResponse(fmt.Sprintf("error making upstream request: %v", err)), nil
12171227
}
@@ -1555,18 +1565,31 @@ func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse,
15551565
return result, err
15561566
}
15571567

1558-
func submitCallerIdentityRequest(method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*GetCallerIdentityResult, error) {
1568+
func submitCallerIdentityRequest(ctx context.Context, maxRetries int, method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*GetCallerIdentityResult, error) {
15591569
// NOTE: We need to ensure we're calling STS, instead of acting as an unintended network proxy
15601570
// The protection against this is that this method will only call the endpoint specified in the
15611571
// client config (defaulting to sts.amazonaws.com), so it would require a Vault admin to override
15621572
// the endpoint to talk to alternate web addresses
15631573
request := buildHttpRequest(method, endpoint, parsedUrl, body, headers)
1574+
retryableReq, err := retryablehttp.FromRequest(request)
1575+
if err != nil {
1576+
return nil, err
1577+
}
1578+
retryableReq = retryableReq.WithContext(ctx)
15641579
client := cleanhttp.DefaultClient()
15651580
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
15661581
return http.ErrUseLastResponse
15671582
}
1583+
retryingClient := &retryablehttp.Client{
1584+
HTTPClient: client,
1585+
RetryWaitMin: retryWaitMin,
1586+
RetryWaitMax: retryWaitMax,
1587+
RetryMax: maxRetries,
1588+
CheckRetry: retryablehttp.DefaultRetryPolicy,
1589+
Backoff: retryablehttp.DefaultBackoff,
1590+
}
15681591

1569-
response, err := client.Do(request)
1592+
response, err := retryingClient.Do(retryableReq)
15701593
if err != nil {
15711594
return nil, errwrap.Wrapf("error making request: {{err}}", err)
15721595
}

0 commit comments

Comments
 (0)