Skip to content

Commit 56f090f

Browse files
authored
Fix race condition in endpoint discovery (#1504)
1 parent a6f44ed commit 56f090f

File tree

5 files changed

+214
-3
lines changed

5 files changed

+214
-3
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "9667162d-c94c-43c7-bb0c-5b0bbcd5ef8a",
3+
"type": "bugfix",
4+
"description": "Fixed a race condition that caused concurrent calls relying on endpoint discovery to share the same `url.URL` reference in their operation's http.Request.",
5+
"modules": [
6+
"service/internal/endpoint-discovery"
7+
]
8+
}

service/internal/endpoint-discovery/cache.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@ func (c *EndpointCache) get(endpointKey string) (Endpoint, bool) {
3535
return Endpoint{}, false
3636
}
3737

38-
c.endpoints.Store(endpointKey, endpoint)
38+
ev := endpoint.(Endpoint)
39+
ev.Prune()
40+
41+
c.endpoints.Store(endpointKey, ev)
3942
return endpoint.(Endpoint), true
4043
}
4144

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package endpointdiscovery
2+
3+
import (
4+
"net/url"
5+
"testing"
6+
"time"
7+
)
8+
9+
func TestEndpointCache_Get_prune(t *testing.T) {
10+
c := NewEndpointCache(2)
11+
c.Add(Endpoint{
12+
Key: "foo",
13+
Addresses: []WeightedAddress{
14+
{
15+
URL: &url.URL{
16+
Host: "foo.amazonaws.com",
17+
},
18+
Expired: time.Now().Add(5 * time.Minute),
19+
},
20+
{
21+
URL: &url.URL{
22+
Host: "bar.amazonaws.com",
23+
},
24+
Expired: time.Now().Add(5 * -time.Minute),
25+
},
26+
},
27+
})
28+
29+
load, _ := c.endpoints.Load("foo")
30+
if ev := load.(Endpoint); len(ev.Addresses) != 2 {
31+
t.Errorf("expected two weighted addresses")
32+
}
33+
34+
weightedAddress, ok := c.Get("foo")
35+
if !ok {
36+
t.Errorf("expect weighted address, got none")
37+
}
38+
if e, a := "foo.amazonaws.com", weightedAddress.URL.Host; e != a {
39+
t.Errorf("expect %v, got %v", e, a)
40+
}
41+
42+
load, _ = c.endpoints.Load("foo")
43+
if ev := load.(Endpoint); len(ev.Addresses) != 1 {
44+
t.Errorf("expected one weighted address")
45+
}
46+
}

service/internal/endpoint-discovery/endpoint.go

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,44 @@ func (e *Endpoint) GetValidAddress() (WeightedAddress, bool) {
5151
we := e.Addresses[i]
5252

5353
if we.HasExpired() {
54-
e.Addresses = append(e.Addresses[:i], e.Addresses[i+1:]...)
55-
i--
5654
continue
5755
}
5856

57+
we.URL = cloneURL(we.URL)
58+
5959
return we, true
6060
}
6161

6262
return WeightedAddress{}, false
6363
}
64+
65+
// Prune will prune the expired addresses from the endpoint by allocating a new []WeightAddress.
66+
// This is not concurrent safe, and should be called from a single owning thread.
67+
func (e *Endpoint) Prune() bool {
68+
validLen := e.Len()
69+
if validLen == len(e.Addresses) {
70+
return false
71+
}
72+
wa := make([]WeightedAddress, 0, validLen)
73+
for i := range e.Addresses {
74+
if e.Addresses[i].HasExpired() {
75+
continue
76+
}
77+
wa = append(wa, e.Addresses[i])
78+
}
79+
e.Addresses = wa
80+
return true
81+
}
82+
83+
func cloneURL(u *url.URL) (clone *url.URL) {
84+
clone = &url.URL{}
85+
86+
*clone = *u
87+
88+
if u.User != nil {
89+
user := *u.User
90+
clone.User = &user
91+
}
92+
93+
return clone
94+
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package endpointdiscovery
2+
3+
import (
4+
"net/url"
5+
"reflect"
6+
"strconv"
7+
"testing"
8+
"time"
9+
)
10+
11+
func Test_cloneURL(t *testing.T) {
12+
tests := []struct {
13+
value *url.URL
14+
wantClone *url.URL
15+
}{
16+
{
17+
value: &url.URL{
18+
Scheme: "https",
19+
Opaque: "foo",
20+
User: nil,
21+
Host: "amazonaws.com",
22+
Path: "/",
23+
RawPath: "/",
24+
ForceQuery: true,
25+
RawQuery: "thing=value",
26+
Fragment: "1234",
27+
RawFragment: "1234",
28+
},
29+
wantClone: &url.URL{
30+
Scheme: "https",
31+
Opaque: "foo",
32+
User: nil,
33+
Host: "amazonaws.com",
34+
Path: "/",
35+
RawPath: "/",
36+
ForceQuery: true,
37+
RawQuery: "thing=value",
38+
Fragment: "1234",
39+
RawFragment: "1234",
40+
},
41+
},
42+
{
43+
value: &url.URL{
44+
Scheme: "https",
45+
Opaque: "foo",
46+
User: url.UserPassword("NOT", "VALID"),
47+
Host: "amazonaws.com",
48+
Path: "/",
49+
RawPath: "/",
50+
ForceQuery: true,
51+
RawQuery: "thing=value",
52+
Fragment: "1234",
53+
RawFragment: "1234",
54+
},
55+
wantClone: &url.URL{
56+
Scheme: "https",
57+
Opaque: "foo",
58+
User: url.UserPassword("NOT", "VALID"),
59+
Host: "amazonaws.com",
60+
Path: "/",
61+
RawPath: "/",
62+
ForceQuery: true,
63+
RawQuery: "thing=value",
64+
Fragment: "1234",
65+
RawFragment: "1234",
66+
},
67+
},
68+
}
69+
for i, tt := range tests {
70+
t.Run(strconv.Itoa(i), func(t *testing.T) {
71+
gotClone := cloneURL(tt.value)
72+
if gotClone == tt.value {
73+
t.Errorf("expct clone URL to not be same pointer address")
74+
}
75+
if tt.value.User != nil {
76+
if tt.value.User == gotClone.User {
77+
t.Errorf("expct cloned Userinfo to not be same pointer address")
78+
}
79+
}
80+
if !reflect.DeepEqual(gotClone, tt.wantClone) {
81+
t.Errorf("cloneURL() = %v, want %v", gotClone, tt.wantClone)
82+
}
83+
})
84+
}
85+
}
86+
87+
func TestEndpoint_Prune(t *testing.T) {
88+
endpoint := Endpoint{}
89+
90+
endpoint.Add(WeightedAddress{
91+
URL: &url.URL{},
92+
Expired: time.Now().Add(5 * time.Minute),
93+
})
94+
95+
initial := endpoint.Addresses
96+
97+
if e, a := false, endpoint.Prune(); e != a {
98+
t.Errorf("expect prune %v, got %v", e, a)
99+
}
100+
101+
if e, a := &initial[0], &endpoint.Addresses[0]; e != a {
102+
t.Errorf("expect slice address to be same")
103+
}
104+
105+
endpoint.Add(WeightedAddress{
106+
URL: &url.URL{},
107+
Expired: time.Now().Add(5 * -time.Minute),
108+
})
109+
110+
initial = endpoint.Addresses
111+
112+
if e, a := true, endpoint.Prune(); e != a {
113+
t.Errorf("expect prune %v, got %v", e, a)
114+
}
115+
116+
if e, a := &initial[0], &endpoint.Addresses[0]; e == a {
117+
t.Errorf("expect slice address to be different")
118+
}
119+
120+
if e, a := 1, endpoint.Len(); e != a {
121+
t.Errorf("expect slice length %v, got %v", e, a)
122+
}
123+
}

0 commit comments

Comments
 (0)