diff --git a/rueidislimiter/go.mod b/rueidislimiter/go.mod new file mode 100644 index 00000000..eba62ece --- /dev/null +++ b/rueidislimiter/go.mod @@ -0,0 +1,17 @@ +module github.com/redis/rueidis/rueidislimiter + +go 1.22.0 + +toolchain go1.24.1 + +replace github.com/redis/rueidis => ../ + +replace github.com/redis/rueidis/mock => ../mock + +require ( + github.com/redis/rueidis v1.0.55 + github.com/redis/rueidis/mock v1.0.55 + go.uber.org/mock v0.5.0 +) + +require golang.org/x/sys v0.30.0 // indirect diff --git a/rueidislimiter/go.sum b/rueidislimiter/go.sum new file mode 100644 index 00000000..043c318c --- /dev/null +++ b/rueidislimiter/go.sum @@ -0,0 +1,14 @@ +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= +github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= +golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/rueidislimiter/limiter.go b/rueidislimiter/limiter.go index 7ff4f028..ef387849 100644 --- a/rueidislimiter/limiter.go +++ b/rueidislimiter/limiter.go @@ -4,7 +4,6 @@ import ( "context" "errors" "strconv" - "strings" "time" "github.com/redis/rueidis" @@ -13,6 +12,9 @@ import ( var ( ErrInvalidTokens = errors.New("number of tokens must be non-negative") ErrInvalidResponse = errors.New("invalid response from Redis") + ErrInvalidLimit = errors.New("limit must be positive") + ErrInvalidWindow = errors.New("window must be positive") + ErrNilBuilder = errors.New("client builder is required") ) type Result struct { @@ -25,9 +27,14 @@ type RateLimiterClient interface { Check(ctx context.Context, identifier string, options ...RateLimitOption) (Result, error) Allow(ctx context.Context, identifier string, options ...RateLimitOption) (Result, error) AllowN(ctx context.Context, identifier string, n int64, options ...RateLimitOption) (Result, error) + Limit() int } -const PlaceholderPrefix = "rueidislimiter" +const ( + PlaceholderPrefix = "rueidislimiter" + keyDelimOpen = ":{" + keyDelimClose = "}" +) type rateLimiter struct { client rueidis.Client @@ -44,11 +51,11 @@ type RateLimiterOption struct { } func NewRateLimiter(option RateLimiterOption) (RateLimiterClient, error) { - if option.Window < time.Millisecond { - option.Window = time.Millisecond + if option.Window <= 0 { + return nil, ErrInvalidWindow } if option.Limit <= 0 { - option.Limit = 1 + return nil, ErrInvalidLimit } if option.KeyPrefix == "" { option.KeyPrefix = PlaceholderPrefix @@ -95,52 +102,60 @@ func (l *rateLimiter) AllowN(ctx context.Context, identifier string, n int64, op rl = options[len(options)-1] } + bufs := rateBuffersPool.Get(0, 128) + defer rateBuffersPool.Put(bufs) + now := time.Now().UTC() - keys := []string{l.getKey(identifier)} - args := []string{ - strconv.FormatInt(n, 10), - strconv.FormatInt(now.Add(rl.window).UnixMilli(), 10), - strconv.FormatInt(now.UnixMilli(), 10), - } - resp := rateLimitScript.Exec(ctx, l.client, keys, args) + offset := len(bufs.keyBuf) + bufs.keyBuf = append(bufs.keyBuf, l.keyPrefix...) + bufs.keyBuf = append(bufs.keyBuf, keyDelimOpen...) + bufs.keyBuf = append(bufs.keyBuf, identifier...) + bufs.keyBuf = append(bufs.keyBuf, keyDelimClose...) + key := rueidis.BinaryString(bufs.keyBuf[offset:]) + + offset = len(bufs.keyBuf) + bufs.keyBuf = strconv.AppendInt(bufs.keyBuf, n, 10) + arg1 := rueidis.BinaryString(bufs.keyBuf[offset:]) + + offset = len(bufs.keyBuf) + bufs.keyBuf = strconv.AppendInt(bufs.keyBuf, now.Add(rl.window).UnixMilli(), 10) + arg2 := rueidis.BinaryString(bufs.keyBuf[offset:]) + + offset = len(bufs.keyBuf) + bufs.keyBuf = strconv.AppendInt(bufs.keyBuf, now.UnixMilli(), 10) + arg3 := rueidis.BinaryString(bufs.keyBuf[offset:]) + + resp := rateLimitScript.Exec(ctx, l.client, []string{key}, []string{arg1, arg2, arg3}) if err := resp.Error(); err != nil { return Result{}, err } - data, err := resp.AsIntSlice() - if err != nil || len(data) != 2 { + arr, err := resp.ToArray() + if err != nil || len(arr) != 2 { return Result{}, ErrInvalidResponse } - current := data[0] - remaining := rl.limit - current - if remaining < 0 { - remaining = 0 + current, err := arr[0].ToInt64() + if err != nil { + return Result{}, ErrInvalidResponse } - allowed := current <= rl.limit - if n == 0 { - allowed = current < rl.limit + resetAt, err := arr[1].ToInt64() + if err != nil { + return Result{}, ErrInvalidResponse } + remaining := max(rl.limit-current, 0) + allowed := current <= rl.limit && (n > 0 || current < rl.limit) + return Result{ Allowed: allowed, Remaining: remaining, - ResetAtMs: data[1], + ResetAtMs: resetAt, }, nil } -func (l *rateLimiter) getKey(identifier string) string { - sb := strings.Builder{} - sb.Grow(len(l.keyPrefix) + len(identifier) + 3) - sb.WriteString(l.keyPrefix) - sb.WriteString(":{") - sb.WriteString(identifier) - sb.WriteString("}") - return sb.String() -} - var rateLimitScript = rueidis.NewLuaScript(` local rate_limit_key = KEYS[1] local increment_amount = tonumber(ARGV[1]) diff --git a/rueidislimiter/limiter_test.go b/rueidislimiter/limiter_test.go index c229f300..4dca6acc 100644 --- a/rueidislimiter/limiter_test.go +++ b/rueidislimiter/limiter_test.go @@ -2,229 +2,392 @@ package rueidislimiter_test import ( "context" - "encoding/binary" - "encoding/hex" - "math/rand" + "errors" "testing" "time" - "unsafe" "github.com/redis/rueidis" + "github.com/redis/rueidis/mock" "github.com/redis/rueidis/rueidislimiter" + "go.uber.org/mock/gomock" ) -func setup(t testing.TB) rueidis.Client { - client, err := rueidis.NewClient(rueidis.ClientOption{InitAddress: []string{"127.0.0.1:6379"}}) - if err != nil { - t.Fatal(err) +func TestNewRateLimiter(t *testing.T) { + tests := []struct { + name string + opt rueidislimiter.RateLimiterOption + wantErr error + }{ + { + name: "default values", + opt: rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return mock.NewClient(gomock.NewController(t)), nil + }, + Limit: 1, + Window: time.Second, + }, + }, + { + name: "custom values", + opt: rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return mock.NewClient(gomock.NewController(t)), nil + }, + Limit: 100, + Window: time.Second, + KeyPrefix: "test:", + }, + }, + { + name: "invalid window", + opt: rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return mock.NewClient(gomock.NewController(t)), nil + }, + Limit: 1, + Window: -time.Second, + }, + wantErr: rueidislimiter.ErrInvalidWindow, + }, + { + name: "invalid limit", + opt: rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return mock.NewClient(gomock.NewController(t)), nil + }, + Limit: -1, + Window: time.Second, + }, + wantErr: rueidislimiter.ErrInvalidLimit, + }, + { + name: "empty key prefix", + opt: rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return mock.NewClient(gomock.NewController(t)), nil + }, + Limit: 1, + Window: time.Second, + }, + }, + { + name: "nil client builder", + opt: rueidislimiter.RateLimiterOption{ + ClientOption: rueidis.ClientOption{InitAddress: []string{"127.0.0.1:6379"}}, + Limit: 1, + Window: time.Second, + }, + }, + { + name: "new client error", + opt: rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return nil, errors.New("client error") + }, + Limit: 1, + Window: time.Second, + }, + wantErr: errors.New("client error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := rueidislimiter.NewRateLimiter(tt.opt) + if tt.wantErr != nil { + if err == nil { + t.Fatalf("NewRateLimiter() error = nil, wantErr %v", tt.wantErr) + } + if err.Error() != tt.wantErr.Error() { + t.Fatalf("NewRateLimiter() error = %v, wantErr %v", err, tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("NewRateLimiter() error = %v, wantErr nil", err) + } + }) + } +} + +func TestRateLimiter_AllowN(t *testing.T) { + now := time.Now() + resetTime := now.Add(time.Second).UnixMilli() + + tests := []struct { + name string + mockResp rueidis.RedisResult + n int64 + customOpt *rueidislimiter.RateLimitOption + wantResult rueidislimiter.Result + wantErr bool + setupMock bool + }{ + { + name: "negative tokens", + n: -1, + wantErr: true, + }, + { + name: "success with default limit", + mockResp: mock.Result(mock.RedisArray( + mock.RedisInt64(1), + mock.RedisInt64(resetTime), + )), + n: 1, + setupMock: true, + wantResult: rueidislimiter.Result{ + Allowed: true, + Remaining: 9, + ResetAtMs: resetTime, + }, + }, + { + name: "success with custom limit", + mockResp: mock.Result(mock.RedisArray( + mock.RedisInt64(5), + mock.RedisInt64(resetTime), + )), + n: 1, + setupMock: true, + customOpt: func() *rueidislimiter.RateLimitOption { + opt := rueidislimiter.WithCustomRateLimit(20, time.Second*2) + return &opt + }(), + wantResult: rueidislimiter.Result{ + Allowed: true, + Remaining: 15, + ResetAtMs: resetTime, + }, + }, + { + name: "limit exceeded", + mockResp: mock.Result(mock.RedisArray( + mock.RedisInt64(11), + mock.RedisInt64(resetTime), + )), + n: 1, + setupMock: true, + wantResult: rueidislimiter.Result{ + Allowed: false, + Remaining: 0, + ResetAtMs: resetTime, + }, + }, + { + name: "redis error", + mockResp: mock.ErrorResult(errors.New("redis error")), + n: 1, + setupMock: true, + wantErr: true, + }, + { + name: "invalid response type", + mockResp: mock.Result(mock.RedisString("invalid")), + n: 1, + setupMock: true, + wantErr: true, + }, + { + name: "invalid array length", + mockResp: mock.Result(mock.RedisArray(mock.RedisInt64(1))), + n: 1, + setupMock: true, + wantErr: true, + }, + { + name: "invalid first element", + mockResp: mock.Result(mock.RedisArray( + mock.RedisString("invalid"), + mock.RedisInt64(1), + )), + n: 1, + setupMock: true, + wantErr: true, + }, + { + name: "invalid second element", + mockResp: mock.Result(mock.RedisArray( + mock.RedisInt64(1), + mock.RedisString("invalid"), + )), + n: 1, + setupMock: true, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + client := mock.NewClient(ctrl) + if tt.setupMock { + client.EXPECT().Do(gomock.Any(), gomock.Any()).Return(tt.mockResp).Times(1) + } + + limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return client, nil + }, + Limit: 10, + Window: time.Second, + }) + if err != nil { + t.Fatal(err) + } + + var got rueidislimiter.Result + if tt.customOpt != nil { + got, err = limiter.AllowN(context.Background(), "test", tt.n, *tt.customOpt) + } else { + got, err = limiter.AllowN(context.Background(), "test", tt.n) + } + + if (err != nil) != tt.wantErr { + t.Fatalf("AllowN() error = %v, wantErr %v", err, tt.wantErr) + } + if err != nil { + return + } + + if got != tt.wantResult { + t.Fatalf("AllowN() = %+v, want %+v", got, tt.wantResult) + } + }) } - return client } -func TestRateLimiter(t *testing.T) { - client := setup(t) - t.Cleanup(client.Close) +func TestRateLimiter_Check(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() now := time.Now() - window := 100 * time.Millisecond + resetTime := now.Add(time.Second).UnixMilli() + + client := mock.NewClient(ctrl) + client.EXPECT().Do(gomock.Any(), gomock.Any()).Return(mock.Result(mock.RedisArray( + mock.RedisInt64(5), + mock.RedisInt64(resetTime), + ))).Times(1) + limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { return client, nil }, - Limit: 3, - Window: window, + Limit: 10, + Window: time.Second, }) if err != nil { t.Fatal(err) } - t.Run("Check defaults", func(t *testing.T) { - limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ - ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { - return client, nil - }, - }) - if err != nil { - t.Fatal(err) - } - result, err := limiter.Check(context.Background(), randStr()) - if err != nil { - t.Fatal(err) - } - if !result.Allowed || result.Remaining != 1 || result.ResetAtMs < now.UnixMilli() { - t.Fatalf("Expected Allowed=true, Remaining=1, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) - } - }) + got, err := limiter.Check(context.Background(), "test") + if err != nil { + t.Fatalf("Check() error = %v", err) + } - t.Run("Check allowed within limit", func(t *testing.T) { - result, err := limiter.Check(context.Background(), randStr()) - if err != nil { - t.Fatal(err) - } - if !result.Allowed || result.Remaining != 3 || result.ResetAtMs < now.UnixMilli() { - t.Fatalf("Expected Allowed=true, Remaining=3, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) - } - }) + want := rueidislimiter.Result{ + Allowed: true, + Remaining: 5, + ResetAtMs: resetTime, + } + if got != want { + t.Fatalf("Check() = %+v, want %+v", got, want) + } +} - t.Run("Check denied after exceeding limit", func(t *testing.T) { - key := randStr() - generateLoad(t, limiter, key, 3) +func TestRateLimiter_Allow(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - result, err := limiter.Check(context.Background(), key) - if err != nil { - t.Fatal(err) - } - if result.Allowed || result.Remaining != 0 || result.ResetAtMs < now.UnixMilli() { - t.Fatalf("Expected Allowed=false, Remaining=0, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) - } - }) + now := time.Now() + resetTime := now.Add(time.Second).UnixMilli() - t.Run("Check allowed after window reset", func(t *testing.T) { - key := randStr() - generateLoad(t, limiter, key, 3) + client := mock.NewClient(ctrl) + client.EXPECT().Do(gomock.Any(), gomock.Any()).Return(mock.Result(mock.RedisArray( + mock.RedisInt64(1), + mock.RedisInt64(resetTime), + ))).Times(1) - // Sleep for slightly longer than window duration to ensure reset - time.Sleep(window * 2) - result, err := limiter.Check(context.Background(), key) - if err != nil { - t.Fatal(err) - } - if !result.Allowed || result.Remaining != 3 || result.ResetAtMs < now.UnixMilli() { - t.Fatalf("Expected Allowed=true, Remaining=3, ResetAt >= now after reset; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) - } + limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return client, nil + }, + Limit: 10, + Window: time.Second, }) + if err != nil { + t.Fatal(err) + } - t.Run("Check allowed with limit option", func(t *testing.T) { - key := randStr() - generateLoad(t, limiter, key, 3) - - result, err := limiter.Check(context.Background(), key) - if err != nil { - t.Fatal(err) - } - if result.Allowed { - t.Fatalf("Expected Allowed=false; got Allowed=%v", result.Allowed) - } + got, err := limiter.Allow(context.Background(), "test") + if err != nil { + t.Fatalf("Allow() error = %v", err) + } - result, err = limiter.Check(context.Background(), key, rueidislimiter.WithCustomRateLimit(10, time.Millisecond*100)) - if err != nil { - t.Fatal(err) - } - if !result.Allowed || result.Remaining != 7 || result.ResetAtMs < now.UnixMilli() { - t.Fatalf("Expected Allowed=true, Remaining=7, ResetAt >= now after reset; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) - } - }) + want := rueidislimiter.Result{ + Allowed: true, + Remaining: 9, + ResetAtMs: resetTime, + } + if got != want { + t.Fatalf("Allow() = %+v, want %+v", got, want) + } +} - t.Run("AllowN defaults", func(t *testing.T) { - limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ - ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { - return client, nil - }, - }) - if err != nil { - t.Fatal(err) - } - result, err := limiter.AllowN(context.Background(), randStr(), 1) - if err != nil { - t.Fatal(err) - } - if !result.Allowed || result.Remaining != 0 || result.ResetAtMs < now.UnixMilli() { - t.Fatalf("Expected Allowed=true, Remaining=0, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) - } - }) +func TestRateLimiter_Limit(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - t.Run("AllowN with tokens within limit", func(t *testing.T) { - key := randStr() - result, err := limiter.AllowN(context.Background(), key, 1) - if err != nil { - t.Fatal(err) - } - if !result.Allowed || result.Remaining != 2 || result.ResetAtMs < now.UnixMilli() { - t.Fatalf("Expected Allowed=true, Remaining=2, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) - } + client := mock.NewClient(ctrl) + limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return client, nil + }, + Limit: 42, + Window: time.Second, }) + if err != nil { + t.Fatal(err) + } - t.Run("AllowN denied after exceeding limit", func(t *testing.T) { - key := randStr() - generateLoad(t, limiter, key, 3) - - result, err := limiter.AllowN(context.Background(), key, 1) - if err != nil { - t.Fatal(err) - } - if result.Allowed || result.Remaining != 0 || result.ResetAtMs < now.UnixMilli() { - t.Fatalf("Expected Allowed=false, Remaining=0, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) - } - }) + if got := limiter.Limit(); got != 42 { + t.Fatalf("Limit() = %v, want %v", got, 42) + } +} - t.Run("AllowN with zero tokens", func(t *testing.T) { - key := randStr() - result, err := limiter.AllowN(context.Background(), key, 0) - if err != nil { - t.Fatal(err) - } - if !result.Allowed { - t.Fatalf("Expected Allowed=true when allowing zero tokens, but got false") - } - }) +func BenchmarkAllowN(b *testing.B) { + ctrl := gomock.NewController(b) + defer ctrl.Finish() - t.Run("AllowN with negative tokens", func(t *testing.T) { - key := randStr() - result, err := limiter.AllowN(context.Background(), key, -1) - if err == nil { - t.Fatalf("Expected error for negative tokens, but got nil") - } - if result.Allowed { - t.Fatalf("Expected Allowed=false when allowing negative tokens, but got true") - } - }) -} + now := time.Now() + resetTime := now.Add(time.Second).UnixMilli() -func BenchmarkRateLimiter(b *testing.B) { - client := setup(b) - defer client.Close() + client := mock.NewClient(ctrl) + client.EXPECT().Do(gomock.Any(), gomock.Any()).Return(mock.Result(mock.RedisArray( + mock.RedisInt64(1), + mock.RedisInt64(resetTime), + ))).Times(b.N) limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { return client, nil }, + Limit: 1000, + Window: time.Second, }) if err != nil { b.Fatal(err) } - key := randStr() b.ResetTimer() - b.ReportAllocs() - - b.Run("Check", func(b *testing.B) { - for i := 0; i < b.N; i++ { - limiter.Check(context.Background(), key) - } - }) - - b.Run("AllowN", func(b *testing.B) { - for i := 0; i < b.N; i++ { - limiter.AllowN(context.Background(), key, 1) - } - }) -} - -func generateLoad(t *testing.T, limiter rueidislimiter.RateLimiterClient, key string, n int) { - for i := 0; i < n; i++ { - _, err := limiter.Allow(context.Background(), key) + for i := 0; i < b.N; i++ { + _, err := limiter.AllowN(context.Background(), "test", 1) if err != nil { - t.Fatal(err) + b.Fatal(err) } } } - -// randStr generates a 24-byte long, random string. -func randStr() string { - b := make([]byte, 24) - binary.LittleEndian.PutUint64(b[12:], rand.Uint64()) - binary.LittleEndian.PutUint32(b[20:], rand.Uint32()) - hex.Encode(b, b[12:]) - - return unsafe.String(unsafe.SliceData(b), len(b)) -} diff --git a/rueidislimiter/syncp.go b/rueidislimiter/syncp.go new file mode 100644 index 00000000..d3cd1de3 --- /dev/null +++ b/rueidislimiter/syncp.go @@ -0,0 +1,21 @@ +package rueidislimiter + +import "github.com/redis/rueidis/internal/util" + +var rateBuffersPool = util.NewPool(func(capacity int) *rateBuffersContainer { + return &rateBuffersContainer{ + keyBuf: make([]byte, 0, capacity), + } +}) + +type rateBuffersContainer struct { + keyBuf []byte +} + +func (r *rateBuffersContainer) Capacity() int { + return cap(r.keyBuf) +} + +func (r *rateBuffersContainer) ResetLen(n int) { + r.keyBuf = r.keyBuf[:0] +}