@@ -2,52 +2,57 @@ package awssign
2
2
3
3
import (
4
4
"bytes"
5
+ "context"
6
+ "crypto/sha256"
7
+ "encoding/base64"
5
8
"fmt"
6
9
"io"
7
10
"net/http"
8
11
"os"
9
12
"time"
10
13
11
- "github.com/aws/aws-sdk-go/aws/credentials"
12
- v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
13
- "github.com/stackql/any-sdk/pkg/logging"
14
+ v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
15
+ "github.com/aws/aws-sdk-go-v2/credentials"
14
16
)
15
17
16
18
var (
17
- _ Transport = & standardAwsSignTransport {}
19
+ _ Transport = & standardAwsSignTransport {}
20
+ emptyPayloadHash string = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
18
21
)
19
22
20
23
type Transport interface {
21
24
RoundTrip (req * http.Request ) (* http.Response , error )
22
25
}
23
26
24
27
type standardAwsSignTransport struct {
25
- underlyingTransport http.RoundTripper
26
- signer * v4.Signer
28
+ underlyingTransport http.RoundTripper
29
+ signer * v4.Signer
30
+ staticCredentialsProvider credentials.StaticCredentialsProvider
27
31
}
28
32
29
33
func NewAwsSignTransport (
30
34
underlyingTransport http.RoundTripper ,
31
35
id , secret , token string ,
32
- options ... func (* v4.Signer ),
36
+ options ... func (* v4.SignerOptions ),
33
37
) (Transport , error ) {
34
- var creds * credentials.Credentials
38
+ var creds credentials.StaticCredentialsProvider
35
39
36
40
if token == "" {
37
- creds = credentials .NewStaticCredentials (id , secret , token )
41
+ creds = credentials .NewStaticCredentialsProvider (id , secret , token )
38
42
} else {
39
43
defaultAccessKeyID := os .Getenv ("AWS_ACCESS_KEY_ID" )
40
44
defaultSecretAccessKey := os .Getenv ("AWS_SECRET_ACCESS_KEY" )
41
45
if defaultAccessKeyID == "" || defaultSecretAccessKey == "" {
42
46
return nil , fmt .Errorf ("AWS_SESSION_TOKEN is set, but AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY must also be set" )
43
47
}
44
- creds = credentials .NewEnvCredentials ( )
48
+ creds = credentials .NewStaticCredentialsProvider ( defaultAccessKeyID , defaultSecretAccessKey , token )
45
49
}
46
50
47
- signer := v4 .NewSigner (creds , options ... )
51
+ signer := v4 .NewSigner (options ... )
48
52
return & standardAwsSignTransport {
49
- underlyingTransport : underlyingTransport ,
50
- signer : signer ,
53
+ underlyingTransport : underlyingTransport ,
54
+ signer : signer ,
55
+ staticCredentialsProvider : creds ,
51
56
}, nil
52
57
}
53
58
@@ -68,23 +73,34 @@ func (t *standardAwsSignTransport) RoundTrip(req *http.Request) (*http.Response,
68
73
if ! ok {
69
74
return nil , fmt .Errorf ("unsupported type for AWS region: '%T'" , rgn )
70
75
}
71
- var rs io.ReadSeeker
76
+ creds , credsErr := t .staticCredentialsProvider .Retrieve (context .TODO ())
77
+ if credsErr != nil {
78
+ return nil , credsErr
79
+ }
80
+
81
+ var payloadHash string
72
82
if req .Body != nil {
73
83
body , err := io .ReadAll (req .Body )
74
84
if err != nil {
75
85
return nil , err
76
86
}
77
- rs = bytes .NewReader (body )
78
- req .Body = nil
87
+ hashBytes := sha256 .Sum256 (body )
88
+ // Base64 encode the hash
89
+ payloadHash = base64 .StdEncoding .EncodeToString (hashBytes [:])
90
+ rs := io .NopCloser (bytes .NewReader (body ))
91
+ req .Body = rs
92
+ } else {
93
+ payloadHash = emptyPayloadHash
79
94
}
80
- header , err := t .signer .Sign (
95
+ err := t .signer .SignHTTP (
96
+ context .TODO (),
97
+ creds ,
81
98
req ,
82
- rs ,
99
+ payloadHash ,
83
100
svcStr ,
84
101
rgnStr ,
85
102
time .Now (),
86
103
)
87
- logging .GetLogger ().Infof ("header = %v\n " , header )
88
104
if err != nil {
89
105
return nil , err
90
106
}
0 commit comments