Skip to content

Commit fa3bfdb

Browse files
authored
fix(auth): handle non-Transport DefaultTransport (#10162)
Since `http.DefaultTransport` is a `RoundTripper` interface and mutable global variable, it is not safe to assume it is always going to concretely be `*http.Transport`. If it is not, I suppose we should just use the value directly instead of making a `Clone`. The `http.DefaultTransport` being overridden is pretty intentional by application authors, so it is best if we respect that and just reuse it direct. Fixes #10159
1 parent 1320d7d commit fa3bfdb

2 files changed

Lines changed: 30 additions & 1 deletion

File tree

auth/httptransport/httptransport.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,14 @@ func AddAuthorizationMiddleware(client *http.Client, creds *auth.Credentials) er
152152
}
153153
base := client.Transport
154154
if base == nil {
155-
base = http.DefaultTransport.(*http.Transport).Clone()
155+
if dt, ok := http.DefaultTransport.(*http.Transport); ok {
156+
base = dt.Clone()
157+
} else {
158+
// Directly reuse the DefaultTransport if the application has
159+
// replaced it with an implementation of RoundTripper other than
160+
// http.Transport.
161+
base = http.DefaultTransport
162+
}
156163
}
157164
client.Transport = &authTransport{
158165
creds: creds,

auth/httptransport/httptransport_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,28 @@ func TestAddAuthorizationMiddleware(t *testing.T) {
8787
}
8888
}
8989

90+
func TestAddAuthorizationMiddleware_HandlesNonTransportAsDefaultTransport(t *testing.T) {
91+
client := &http.Client{}
92+
creds := auth.NewCredentials(&auth.CredentialsOptions{
93+
TokenProvider: staticTP("fakeToken"),
94+
})
95+
dt := http.DefaultTransport
96+
97+
http.DefaultTransport = &rt{}
98+
defer func() { http.DefaultTransport = dt }()
99+
100+
err := AddAuthorizationMiddleware(client, creds)
101+
if err != nil {
102+
t.Fatal(err)
103+
}
104+
105+
at := client.Transport.(*authTransport)
106+
_, ok := at.base.(*rt)
107+
if !ok {
108+
t.Errorf("got %T, want %T", at.base, &rt{})
109+
}
110+
}
111+
90112
func TestNewClient_FailsValidation(t *testing.T) {
91113
tests := []struct {
92114
name string

0 commit comments

Comments
 (0)