Skip to content

Commit 4cdb267

Browse files
committed
feat: add rate limit for request paypal api
1 parent 5498dc5 commit 4cdb267

File tree

3 files changed

+181
-2
lines changed

3 files changed

+181
-2
lines changed

client.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@ import (
1010
"net/http"
1111
"net/http/httputil"
1212
"time"
13+
14+
"github.com/plutov/paypal/v4/limiter"
1315
)
1416

17+
var ErrRateLimited = errors.New("rate limited")
18+
1519
// NewClient returns new Client struct
1620
// APIBase is a base API URL, for testing you can use paypal.APIBaseSandBox
1721
func NewClient(clientID string, secret string, APIBase string) (*Client, error) {
@@ -76,6 +80,15 @@ func (c *Client) SetReturnRepresentation() {
7680
c.returnRepresentation = true
7781
}
7882

83+
// SetLimiter configures an optional rate limiter used by SendWithAuth.
84+
// key is used as the rate limit bucket key; if empty, a default is used.
85+
func (c *Client) SetLimiter(l limiter.RateLimiter, key string) {
86+
c.mu.Lock()
87+
defer c.mu.Unlock()
88+
c.rateLimiter = l
89+
c.rateLimiterKey = key
90+
}
91+
7992
// Send makes a request to the API, the response body will be
8093
// unmarshalled into v, or if v is an io.Writer, the response will
8194
// be written to it without decoding
@@ -166,7 +179,34 @@ func (c *Client) SendWithAuth(req *http.Request, v interface{}) error {
166179
// Note: Here we do not want to `defer c.Unlock()` because we need `c.Send(...)`
167180
// to happen outside of the locked section.
168181

169-
if c.Token == nil || (!c.tokenExpiresAt.IsZero() && time.Until(c.tokenExpiresAt) < RequestNewTokenBeforeExpiresIn) {
182+
// determine if a token request is needed under the lock
183+
needToken := c.Token == nil || (!c.tokenExpiresAt.IsZero() && time.Until(c.tokenExpiresAt) < RequestNewTokenBeforeExpiresIn)
184+
185+
// optional rate limiting
186+
if c.rateLimiter != nil {
187+
key := c.rateLimiterKey
188+
if key == "" {
189+
key = "paypal:client"
190+
}
191+
ctx := req.Context()
192+
permits := 1
193+
if needToken {
194+
permits = 2 // one for token request, one for the main API call
195+
}
196+
for i := 0; i < permits; i++ {
197+
dec, err := c.rateLimiter.Allow(ctx, key)
198+
if err != nil {
199+
c.mu.Unlock()
200+
return err
201+
}
202+
if !dec.Allowed {
203+
c.mu.Unlock()
204+
return ErrRateLimited
205+
}
206+
}
207+
}
208+
209+
if needToken {
170210
// c.Token will be updated in GetAccessToken call
171211
if _, err := c.GetAccessToken(req.Context()); err != nil {
172212
// c.Unlock()

client_limiter_test.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package paypal
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"net/http"
7+
"net/http/httptest"
8+
"sync/atomic"
9+
"testing"
10+
"time"
11+
12+
"github.com/plutov/paypal/v4/limiter"
13+
"github.com/stretchr/testify/assert"
14+
)
15+
16+
type okResp struct {
17+
OK bool `json:"ok"`
18+
}
19+
20+
func newTestServer() (*httptest.Server, *int32, *int32) {
21+
tokenHits := new(int32)
22+
apiHits := new(int32)
23+
mux := http.NewServeMux()
24+
mux.HandleFunc("/v1/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
25+
atomic.AddInt32(tokenHits, 1)
26+
w.Header().Set("Content-Type", "application/json")
27+
_ = json.NewEncoder(w).Encode(map[string]any{
28+
"access_token": "test-token",
29+
"token_type": "Bearer",
30+
"expires_in": 3600,
31+
})
32+
})
33+
mux.HandleFunc("/api", func(w http.ResponseWriter, r *http.Request) {
34+
atomic.AddInt32(apiHits, 1)
35+
if r.Header.Get("Authorization") == "" {
36+
w.WriteHeader(http.StatusUnauthorized)
37+
return
38+
}
39+
w.Header().Set("Content-Type", "application/json")
40+
_ = json.NewEncoder(w).Encode(okResp{OK: true})
41+
})
42+
return httptest.NewServer(mux), tokenHits, apiHits
43+
}
44+
45+
func newClientForServer(t *testing.T, s *httptest.Server) *Client {
46+
t.Helper()
47+
c, err := NewClient("id", "secret", s.URL)
48+
if err != nil {
49+
t.Fatalf("new client: %v", err)
50+
}
51+
c.SetHTTPClient(s.Client())
52+
return c
53+
}
54+
55+
func TestClientLimiter_BlocksWhenInsufficientPermitsForTokenAndRequest(t *testing.T) {
56+
s, tokenHits, apiHits := newTestServer()
57+
defer s.Close()
58+
c := newClientForServer(t, s)
59+
60+
// limiter allows only 1 request, but we need 2 (token + api)
61+
lim := limiter.NewFixedWindowLimiter(limiter.NewMemoryStorage(), limiter.FixedWindowConfig{Window: time.Minute, Limit: 1})
62+
c.SetLimiter(lim, "test-key")
63+
64+
req, _ := c.NewRequest(context.Background(), http.MethodGet, s.URL+"/api", nil)
65+
var out okResp
66+
err := c.SendWithAuth(req, &out)
67+
assert.Error(t, err)
68+
assert.Equal(t, ErrRateLimited, err)
69+
// should have short-circuited before hitting network
70+
assert.Equal(t, int32(0), atomic.LoadInt32(tokenHits))
71+
assert.Equal(t, int32(0), atomic.LoadInt32(apiHits))
72+
}
73+
74+
func TestClientLimiter_AllowsWithSufficientPermitsForTokenAndRequest(t *testing.T) {
75+
s, tokenHits, apiHits := newTestServer()
76+
defer s.Close()
77+
c := newClientForServer(t, s)
78+
79+
// need 2 permits and we have 2
80+
lim := limiter.NewFixedWindowLimiter(limiter.NewMemoryStorage(), limiter.FixedWindowConfig{Window: time.Minute, Limit: 2})
81+
c.SetLimiter(lim, "test-key")
82+
83+
req, _ := c.NewRequest(context.Background(), http.MethodGet, s.URL+"/api", nil)
84+
var out okResp
85+
err := c.SendWithAuth(req, &out)
86+
assert.NoError(t, err)
87+
assert.True(t, out.OK)
88+
assert.Equal(t, int32(1), atomic.LoadInt32(tokenHits))
89+
assert.Equal(t, int32(1), atomic.LoadInt32(apiHits))
90+
}
91+
92+
func TestClientLimiter_ConsumesOnePermitWhenTokenAlreadySet(t *testing.T) {
93+
s, tokenHits, apiHits := newTestServer()
94+
defer s.Close()
95+
c := newClientForServer(t, s)
96+
97+
// set token so we don't need refresh
98+
c.SetAccessToken("preset")
99+
// allow only 1 permit
100+
lim := limiter.NewFixedWindowLimiter(limiter.NewMemoryStorage(), limiter.FixedWindowConfig{Window: time.Minute, Limit: 1})
101+
c.SetLimiter(lim, "test-key")
102+
103+
req, _ := c.NewRequest(context.Background(), http.MethodGet, s.URL+"/api", nil)
104+
var out okResp
105+
err := c.SendWithAuth(req, &out)
106+
assert.NoError(t, err)
107+
assert.True(t, out.OK)
108+
assert.Equal(t, int32(0), atomic.LoadInt32(*&tokenHits))
109+
assert.Equal(t, int32(1), atomic.LoadInt32(apiHits))
110+
}
111+
112+
func TestClientLimiter_TokenSoonToExpire_ConsumesTwoPermits(t *testing.T) {
113+
s, tokenHits, apiHits := newTestServer()
114+
defer s.Close()
115+
c := newClientForServer(t, s)
116+
117+
// pre-existing token but expiring soon (< 60s)
118+
c.Token = &TokenResponse{Token: "old"}
119+
c.tokenExpiresAt = time.Now().Add(5 * time.Second)
120+
121+
// need 2 permits; provide exactly 2
122+
lim := limiter.NewFixedWindowLimiter(limiter.NewMemoryStorage(), limiter.FixedWindowConfig{Window: time.Minute, Limit: 2})
123+
c.SetLimiter(lim, "test-key")
124+
125+
req, _ := c.NewRequest(context.Background(), http.MethodGet, s.URL+"/api", nil)
126+
var out okResp
127+
err := c.SendWithAuth(req, &out)
128+
assert.NoError(t, err)
129+
assert.True(t, out.OK)
130+
assert.Equal(t, int32(1), atomic.LoadInt32(tokenHits))
131+
assert.Equal(t, int32(1), atomic.LoadInt32(apiHits))
132+
}

types.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ import (
88
"strings"
99
"sync"
1010
"time"
11+
12+
// add limiter support
13+
"github.com/plutov/paypal/v4/limiter"
1114
)
1215

1316
const (
@@ -646,6 +649,10 @@ type (
646649
Token *TokenResponse
647650
tokenExpiresAt time.Time
648651
returnRepresentation bool
652+
653+
// optional rate limiter; if set, SendWithAuth will check it before making requests
654+
rateLimiter limiter.RateLimiter
655+
rateLimiterKey string
649656
}
650657

651658
// CreditCard struct
@@ -1594,7 +1601,7 @@ func (t JSONTime) MarshalJSON() ([]byte, error) {
15941601
return []byte(stamp), nil
15951602
}
15961603

1597-
// UnmarshalJSON for JSONTime, timezone offset is missing a colon ':"
1604+
// UnmarshalJSON for JSONTime, timezone offset is missing a colon ':'
15981605
func (t *JSONTime) UnmarshalJSON(b []byte) error {
15991606
s := strings.Trim(string(b), `"`)
16001607
nt, err := time.Parse("2006-01-02T15:04:05Z0700", s)

0 commit comments

Comments
 (0)