@@ -16,11 +16,13 @@ import (
16
16
"time"
17
17
18
18
"github.com/aws/aws-sdk-go/aws"
19
+ awsClient "github.com/aws/aws-sdk-go/aws/client"
19
20
"github.com/aws/aws-sdk-go/service/ec2"
20
21
"github.com/aws/aws-sdk-go/service/iam"
21
22
"github.com/fullsailor/pkcs7"
22
23
"github.com/hashicorp/errwrap"
23
24
cleanhttp "github.com/hashicorp/go-cleanhttp"
25
+ "github.com/hashicorp/go-retryablehttp"
24
26
uuid "github.com/hashicorp/go-uuid"
25
27
"github.com/hashicorp/vault/sdk/framework"
26
28
"github.com/hashicorp/vault/sdk/helper/awsutil"
@@ -35,6 +37,10 @@ const (
35
37
iamAuthType = "iam"
36
38
ec2AuthType = "ec2"
37
39
ec2EntityType = "ec2_instance"
40
+
41
+ // Retry configuration
42
+ retryWaitMin = 500 * time .Millisecond
43
+ retryWaitMax = 30 * time .Second
38
44
)
39
45
40
46
func (b * backend ) pathLogin () * framework.Path {
@@ -1199,6 +1205,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
1199
1205
1200
1206
endpoint := "https://sts.amazonaws.com"
1201
1207
1208
+ maxRetries := awsClient .DefaultRetryerMaxNumRetries
1202
1209
if config != nil {
1203
1210
if config .IAMServerIdHeaderValue != "" {
1204
1211
err = validateVaultHeaderValue (headers , parsedUrl , config .IAMServerIdHeaderValue )
@@ -1209,9 +1216,12 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
1209
1216
if config .STSEndpoint != "" {
1210
1217
endpoint = config .STSEndpoint
1211
1218
}
1219
+ if config .MaxRetries >= 0 {
1220
+ maxRetries = config .MaxRetries
1221
+ }
1212
1222
}
1213
1223
1214
- callerID , err := submitCallerIdentityRequest (method , endpoint , parsedUrl , body , headers )
1224
+ callerID , err := submitCallerIdentityRequest (ctx , maxRetries , method , endpoint , parsedUrl , body , headers )
1215
1225
if err != nil {
1216
1226
return logical .ErrorResponse (fmt .Sprintf ("error making upstream request: %v" , err )), nil
1217
1227
}
@@ -1555,18 +1565,31 @@ func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse,
1555
1565
return result , err
1556
1566
}
1557
1567
1558
- func submitCallerIdentityRequest (method , endpoint string , parsedUrl * url.URL , body string , headers http.Header ) (* GetCallerIdentityResult , error ) {
1568
+ func submitCallerIdentityRequest (ctx context. Context , maxRetries int , method , endpoint string , parsedUrl * url.URL , body string , headers http.Header ) (* GetCallerIdentityResult , error ) {
1559
1569
// NOTE: We need to ensure we're calling STS, instead of acting as an unintended network proxy
1560
1570
// The protection against this is that this method will only call the endpoint specified in the
1561
1571
// client config (defaulting to sts.amazonaws.com), so it would require a Vault admin to override
1562
1572
// the endpoint to talk to alternate web addresses
1563
1573
request := buildHttpRequest (method , endpoint , parsedUrl , body , headers )
1574
+ retryableReq , err := retryablehttp .FromRequest (request )
1575
+ if err != nil {
1576
+ return nil , err
1577
+ }
1578
+ retryableReq = retryableReq .WithContext (ctx )
1564
1579
client := cleanhttp .DefaultClient ()
1565
1580
client .CheckRedirect = func (req * http.Request , via []* http.Request ) error {
1566
1581
return http .ErrUseLastResponse
1567
1582
}
1583
+ retryingClient := & retryablehttp.Client {
1584
+ HTTPClient : client ,
1585
+ RetryWaitMin : retryWaitMin ,
1586
+ RetryWaitMax : retryWaitMax ,
1587
+ RetryMax : maxRetries ,
1588
+ CheckRetry : retryablehttp .DefaultRetryPolicy ,
1589
+ Backoff : retryablehttp .DefaultBackoff ,
1590
+ }
1568
1591
1569
- response , err := client .Do (request )
1592
+ response , err := retryingClient .Do (retryableReq )
1570
1593
if err != nil {
1571
1594
return nil , errwrap .Wrapf ("error making request: {{err}}" , err )
1572
1595
}
0 commit comments