Skip to content

Commit 3618d3f

Browse files
authored
feat(auth): add UniverseDomain to DetectOptions (#9536)
* Enable universe domain mismatch checks in transport packages
1 parent cc64719 commit 3618d3f

19 files changed

Lines changed: 496 additions & 111 deletions

File tree

auth/auth.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -284,18 +284,18 @@ type Error struct {
284284
uri string
285285
}
286286

287-
func (r *Error) Error() string {
288-
if r.code != "" {
289-
s := fmt.Sprintf("auth: %q", r.code)
290-
if r.description != "" {
291-
s += fmt.Sprintf(" %q", r.description)
287+
func (e *Error) Error() string {
288+
if e.code != "" {
289+
s := fmt.Sprintf("auth: %q", e.code)
290+
if e.description != "" {
291+
s += fmt.Sprintf(" %q", e.description)
292292
}
293-
if r.uri != "" {
294-
s += fmt.Sprintf(" %q", r.uri)
293+
if e.uri != "" {
294+
s += fmt.Sprintf(" %q", e.uri)
295295
}
296296
return s
297297
}
298-
return fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", r.Response.StatusCode, r.Body)
298+
return fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", e.Response.StatusCode, e.Body)
299299
}
300300

301301
// Temporary returns true if the error is considered temporary and may be able

auth/credentials/detect.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ func DetectDefault(opts *DetectOptions) (*auth.Credentials, error) {
9393
ProjectIDProvider: auth.CredentialsPropertyFunc(func(context.Context) (string, error) {
9494
return metadata.ProjectID()
9595
}),
96+
UniverseDomainProvider: &internal.ComputeUniverseDomainProvider{},
9697
}), nil
9798
}
9899

@@ -140,6 +141,9 @@ type DetectOptions struct {
140141
// Client configures the underlying client used to make network requests
141142
// when fetching tokens. Optional.
142143
Client *http.Client
144+
// UniverseDomain is the default service domain for a given Cloud universe.
145+
// The default value is "googleapis.com". Optional.
146+
UniverseDomain string
143147
}
144148

145149
func (o *DetectOptions) validate() error {

auth/credentials/detect_test.go

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ func TestDefaultCredentials_ExternalAccountKey(t *testing.T) {
596596
if want := "googleapis.com"; got != want {
597597
t.Fatalf("got %q, want %q", got, want)
598598
}
599-
tok, err := creds.Token(context.Background())
599+
tok, err := creds.Token(ctx)
600600
if err != nil {
601601
t.Fatalf("creds.Token() = %v", err)
602602
}
@@ -720,3 +720,111 @@ func TestDefaultCredentials_Validate(t *testing.T) {
720720
})
721721
}
722722
}
723+
724+
func TestDefaultCredentials_UniverseDomain(t *testing.T) {
725+
ctx := context.Background()
726+
tests := []struct {
727+
name string
728+
opts *DetectOptions
729+
want string
730+
}{
731+
{
732+
name: "user json",
733+
opts: &DetectOptions{
734+
CredentialsFile: "../internal/testdata/user.json",
735+
TokenURL: "example.com",
736+
},
737+
want: "googleapis.com",
738+
},
739+
{
740+
name: "user json with file universe domain",
741+
opts: &DetectOptions{
742+
CredentialsFile: "../internal/testdata/user_universe_domain.json",
743+
TokenURL: "example.com",
744+
},
745+
want: "googleapis.com",
746+
},
747+
{
748+
name: "service account token URL json",
749+
opts: &DetectOptions{
750+
CredentialsFile: "../internal/testdata/sa.json",
751+
},
752+
want: "googleapis.com",
753+
},
754+
{
755+
name: "external account json",
756+
opts: &DetectOptions{
757+
CredentialsFile: "../internal/testdata/exaccount_user.json",
758+
UseSelfSignedJWT: true,
759+
},
760+
want: "googleapis.com",
761+
},
762+
{
763+
name: "service account impersonation json",
764+
opts: &DetectOptions{
765+
CredentialsFile: "../internal/testdata/imp.json",
766+
UseSelfSignedJWT: true,
767+
},
768+
want: "googleapis.com",
769+
},
770+
{
771+
name: "service account json with file universe domain",
772+
opts: &DetectOptions{
773+
CredentialsFile: "../internal/testdata/sa_universe_domain.json",
774+
UseSelfSignedJWT: true,
775+
},
776+
want: "example.com",
777+
},
778+
{
779+
name: "service account json with options universe domain",
780+
opts: &DetectOptions{
781+
CredentialsFile: "../internal/testdata/sa.json",
782+
UseSelfSignedJWT: true,
783+
UniverseDomain: "foo.com",
784+
},
785+
want: "foo.com",
786+
},
787+
{
788+
name: "service account json with file and options universe domain",
789+
opts: &DetectOptions{
790+
CredentialsFile: "../internal/testdata/sa_universe_domain.json",
791+
UseSelfSignedJWT: true,
792+
UniverseDomain: "bar.com",
793+
},
794+
want: "bar.com",
795+
},
796+
{
797+
name: "external account json with options universe domain",
798+
opts: &DetectOptions{
799+
CredentialsFile: "../internal/testdata/exaccount_user.json",
800+
UseSelfSignedJWT: true,
801+
UniverseDomain: "foo.com",
802+
},
803+
want: "foo.com",
804+
},
805+
{
806+
name: "impersonated service account json with options universe domain",
807+
opts: &DetectOptions{
808+
CredentialsFile: "../internal/testdata/imp.json",
809+
UseSelfSignedJWT: true,
810+
UniverseDomain: "foo.com",
811+
},
812+
want: "foo.com",
813+
},
814+
}
815+
for _, tt := range tests {
816+
t.Run(tt.name, func(t *testing.T) {
817+
creds, err := DetectDefault(tt.opts)
818+
if err != nil {
819+
t.Fatalf("%s: %v", tt.name, err)
820+
}
821+
ud, err := creds.UniverseDomain(ctx)
822+
if err != nil {
823+
t.Fatalf("%s: %v", tt.name, err)
824+
}
825+
if ud != tt.want {
826+
t.Fatalf("%s: got %q, want %q", tt.name, ud, tt.want)
827+
}
828+
})
829+
}
830+
}

auth/credentials/filetypes.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ func fileCredentials(b []byte, opts *DetectOptions) (*auth.Credentials, error) {
101101
default:
102102
return nil, fmt.Errorf("detect: unsupported filetype %q", fileType)
103103
}
104+
if opts.UniverseDomain != "" {
105+
universeDomain = opts.UniverseDomain
106+
}
104107
return auth.NewCredentials(&auth.CredentialsOptions{
105108
TokenProvider: auth.NewCachedTokenProvider(tp, &auth.CachedTokenProviderOptions{
106109
ExpireEarly: opts.EarlyTokenRefresh,

auth/credentials/internal/externalaccount/aws_provider.go

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -173,17 +173,17 @@ func (sp *awsSubjectProvider) providerType() string {
173173
return awsProviderType
174174
}
175175

176-
func (cs *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, error) {
177-
if cs.IMDSv2SessionTokenURL == "" {
176+
func (sp *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, error) {
177+
if sp.IMDSv2SessionTokenURL == "" {
178178
return "", nil
179179
}
180-
req, err := http.NewRequestWithContext(ctx, "PUT", cs.IMDSv2SessionTokenURL, nil)
180+
req, err := http.NewRequestWithContext(ctx, "PUT", sp.IMDSv2SessionTokenURL, nil)
181181
if err != nil {
182182
return "", err
183183
}
184184
req.Header.Set(awsIMDSv2SessionTTLHeader, awsIMDSv2SessionTTL)
185185

186-
resp, err := cs.Client.Do(req)
186+
resp, err := sp.Client.Do(req)
187187
if err != nil {
188188
return "", err
189189
}
@@ -199,19 +199,19 @@ func (cs *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, e
199199
return string(respBody), nil
200200
}
201201

202-
func (cs *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]string) (string, error) {
202+
func (sp *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]string) (string, error) {
203203
if canRetrieveRegionFromEnvironment() {
204204
if envAwsRegion := getenv(awsRegionEnvVar); envAwsRegion != "" {
205205
return envAwsRegion, nil
206206
}
207207
return getenv(awsDefaultRegionEnvVar), nil
208208
}
209209

210-
if cs.RegionURL == "" {
210+
if sp.RegionURL == "" {
211211
return "", errors.New("detect: unable to determine AWS region")
212212
}
213213

214-
req, err := http.NewRequestWithContext(ctx, "GET", cs.RegionURL, nil)
214+
req, err := http.NewRequestWithContext(ctx, "GET", sp.RegionURL, nil)
215215
if err != nil {
216216
return "", err
217217
}
@@ -220,7 +220,7 @@ func (cs *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]
220220
req.Header.Add(name, value)
221221
}
222222

223-
resp, err := cs.Client.Do(req)
223+
resp, err := sp.Client.Do(req)
224224
if err != nil {
225225
return "", err
226226
}
@@ -244,7 +244,7 @@ func (cs *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]
244244
return string(respBody[:bodyLen-1]), nil
245245
}
246246

247-
func (cs *awsSubjectProvider) getSecurityCredentials(ctx context.Context, headers map[string]string) (result awsSecurityCredentials, err error) {
247+
func (sp *awsSubjectProvider) getSecurityCredentials(ctx context.Context, headers map[string]string) (result awsSecurityCredentials, err error) {
248248
if canRetrieveSecurityCredentialFromEnvironment() {
249249
return awsSecurityCredentials{
250250
AccessKeyID: getenv(awsAccessKeyIDEnvVar),
@@ -253,11 +253,11 @@ func (cs *awsSubjectProvider) getSecurityCredentials(ctx context.Context, header
253253
}, nil
254254
}
255255

256-
roleName, err := cs.getMetadataRoleName(ctx, headers)
256+
roleName, err := sp.getMetadataRoleName(ctx, headers)
257257
if err != nil {
258258
return
259259
}
260-
credentials, err := cs.getMetadataSecurityCredentials(ctx, roleName, headers)
260+
credentials, err := sp.getMetadataSecurityCredentials(ctx, roleName, headers)
261261
if err != nil {
262262
return
263263
}
@@ -272,18 +272,18 @@ func (cs *awsSubjectProvider) getSecurityCredentials(ctx context.Context, header
272272
return credentials, nil
273273
}
274274

275-
func (cs *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context, roleName string, headers map[string]string) (awsSecurityCredentials, error) {
275+
func (sp *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context, roleName string, headers map[string]string) (awsSecurityCredentials, error) {
276276
var result awsSecurityCredentials
277277

278-
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
278+
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/%s", sp.CredVerificationURL, roleName), nil)
279279
if err != nil {
280280
return result, err
281281
}
282282
for name, value := range headers {
283283
req.Header.Add(name, value)
284284
}
285285

286-
resp, err := cs.Client.Do(req)
286+
resp, err := sp.Client.Do(req)
287287
if err != nil {
288288
return result, err
289289
}
@@ -300,19 +300,19 @@ func (cs *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context
300300
return result, err
301301
}
302302

303-
func (cs *awsSubjectProvider) getMetadataRoleName(ctx context.Context, headers map[string]string) (string, error) {
304-
if cs.CredVerificationURL == "" {
303+
func (sp *awsSubjectProvider) getMetadataRoleName(ctx context.Context, headers map[string]string) (string, error) {
304+
if sp.CredVerificationURL == "" {
305305
return "", errors.New("detect: unable to determine the AWS metadata server security credentials endpoint")
306306
}
307-
req, err := http.NewRequestWithContext(ctx, "GET", cs.CredVerificationURL, nil)
307+
req, err := http.NewRequestWithContext(ctx, "GET", sp.CredVerificationURL, nil)
308308
if err != nil {
309309
return "", err
310310
}
311311
for name, value := range headers {
312312
req.Header.Add(name, value)
313313
}
314314

315-
resp, err := cs.Client.Do(req)
315+
resp, err := sp.Client.Do(req)
316316
if err != nil {
317317
return "", err
318318
}

auth/credentials/internal/externalaccount/executable_provider.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ type executableResponse struct {
122122
Message string `json:"message,omitempty"`
123123
}
124124

125-
func (cs *executableSubjectProvider) parseSubjectTokenFromSource(response []byte, source string, now int64) (string, error) {
125+
func (sp *executableSubjectProvider) parseSubjectTokenFromSource(response []byte, source string, now int64) (string, error) {
126126
var result executableResponse
127127
if err := json.Unmarshal(response, &result); err != nil {
128128
return "", jsonParsingError(source, string(response))
@@ -143,7 +143,7 @@ func (cs *executableSubjectProvider) parseSubjectTokenFromSource(response []byte
143143
if result.Version > executableSupportedMaxVersion || result.Version < 0 {
144144
return "", unsupportedVersionError(source, result.Version)
145145
}
146-
if result.ExpirationTime == 0 && cs.OutputFile != "" {
146+
if result.ExpirationTime == 0 && sp.OutputFile != "" {
147147
return "", missingFieldError(source, "expiration_time")
148148
}
149149
if result.TokenType == "" {
@@ -169,24 +169,24 @@ func (cs *executableSubjectProvider) parseSubjectTokenFromSource(response []byte
169169
}
170170
}
171171

172-
func (cs *executableSubjectProvider) subjectToken(ctx context.Context) (string, error) {
173-
if token, err := cs.getTokenFromOutputFile(); token != "" || err != nil {
172+
func (sp *executableSubjectProvider) subjectToken(ctx context.Context) (string, error) {
173+
if token, err := sp.getTokenFromOutputFile(); token != "" || err != nil {
174174
return token, err
175175
}
176-
return cs.getTokenFromExecutableCommand(ctx)
176+
return sp.getTokenFromExecutableCommand(ctx)
177177
}
178178

179-
func (cs *executableSubjectProvider) providerType() string {
179+
func (sp *executableSubjectProvider) providerType() string {
180180
return executableProviderType
181181
}
182182

183-
func (cs *executableSubjectProvider) getTokenFromOutputFile() (token string, err error) {
184-
if cs.OutputFile == "" {
183+
func (sp *executableSubjectProvider) getTokenFromOutputFile() (token string, err error) {
184+
if sp.OutputFile == "" {
185185
// This ExecutableCredentialSource doesn't use an OutputFile.
186186
return "", nil
187187
}
188188

189-
file, err := os.Open(cs.OutputFile)
189+
file, err := os.Open(sp.OutputFile)
190190
if err != nil {
191191
// No OutputFile found. Hasn't been created yet, so skip it.
192192
return "", nil
@@ -199,7 +199,7 @@ func (cs *executableSubjectProvider) getTokenFromOutputFile() (token string, err
199199
return "", nil
200200
}
201201

202-
token, err = cs.parseSubjectTokenFromSource(data, outputFileSource, cs.env.now().Unix())
202+
token, err = sp.parseSubjectTokenFromSource(data, outputFileSource, sp.env.now().Unix())
203203
if err != nil {
204204
if _, ok := err.(nonCacheableError); ok {
205205
// If the cached token is expired we need a new token,
@@ -231,20 +231,20 @@ func (sp *executableSubjectProvider) executableEnvironment() []string {
231231
return result
232232
}
233233

234-
func (cs *executableSubjectProvider) getTokenFromExecutableCommand(ctx context.Context) (string, error) {
234+
func (sp *executableSubjectProvider) getTokenFromExecutableCommand(ctx context.Context) (string, error) {
235235
// For security reasons, we need our consumers to set this environment variable to allow executables to be run.
236-
if cs.env.getenv(allowExecutablesEnvVar) != "1" {
236+
if sp.env.getenv(allowExecutablesEnvVar) != "1" {
237237
return "", errors.New("detect: executables need to be explicitly allowed (set GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES to '1') to run")
238238
}
239239

240-
ctx, cancel := context.WithDeadline(ctx, cs.env.now().Add(cs.Timeout))
240+
ctx, cancel := context.WithDeadline(ctx, sp.env.now().Add(sp.Timeout))
241241
defer cancel()
242242

243-
output, err := cs.env.run(ctx, cs.Command, cs.executableEnvironment())
243+
output, err := sp.env.run(ctx, sp.Command, sp.executableEnvironment())
244244
if err != nil {
245245
return "", err
246246
}
247-
return cs.parseSubjectTokenFromSource(output, executableSource, cs.env.now().Unix())
247+
return sp.parseSubjectTokenFromSource(output, executableSource, sp.env.now().Unix())
248248
}
249249

250250
func missingFieldError(source, field string) error {

0 commit comments

Comments
 (0)