Skip to content

Commit b2485e6

Browse files
authored
Merge pull request #1433 from apernet/fix-acl-cache
fix(acl): error rule applied due to bad cache key
2 parents a695a9c + 5ef8569 commit b2485e6

File tree

2 files changed

+57
-6
lines changed

2 files changed

+57
-6
lines changed

extras/outbounds/acl/compile.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,19 @@ const (
1919
ProtocolUDP
2020
)
2121

22+
func (p Protocol) String() string {
23+
switch p {
24+
case ProtocolBoth:
25+
return "tcp+udp"
26+
case ProtocolTCP:
27+
return "tcp"
28+
case ProtocolUDP:
29+
return "udp"
30+
default:
31+
return fmt.Sprintf("Protocol(%d)", int(p))
32+
}
33+
}
34+
2235
type Outbound interface {
2336
any
2437
}
@@ -63,12 +76,22 @@ type matchResult[O Outbound] struct {
6376

6477
type compiledRuleSetImpl[O Outbound] struct {
6578
Rules []compiledRule[O]
66-
Cache *lru.Cache[string, matchResult[O]] // key: HostInfo.String()
79+
Cache *lru.Cache[matchResultCacheKey, matchResult[O]] // key: HostInfo.String()
80+
}
81+
82+
type matchResultCacheKey struct {
83+
Host string
84+
Proto Protocol
85+
Port uint16
6786
}
6887

6988
func (s *compiledRuleSetImpl[O]) Match(host HostInfo, proto Protocol, port uint16) (O, net.IP) {
7089
host.Name = strings.ToLower(host.Name) // Normalize host name to lower case
71-
key := host.String()
90+
key := matchResultCacheKey{
91+
Host: host.String(),
92+
Proto: proto,
93+
Port: port,
94+
}
7295
if result, ok := s.Cache.Get(key); ok {
7396
return result.Outbound, result.HijackAddress
7497
}
@@ -130,7 +153,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
130153
}
131154
compiledRules[i] = compiledRule[O]{outbound, hm, proto, startPort, endPort, hijackAddress}
132155
}
133-
cache, err := lru.New[string, matchResult[O]](cacheSize)
156+
cache, err := lru.New[matchResultCacheKey, matchResult[O]](cacheSize)
134157
if err != nil {
135158
return nil, err
136159
}

extras/outbounds/acl/compile_test.go

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package acl
22

33
import (
4+
"fmt"
45
"net"
56
"testing"
67

@@ -177,6 +178,30 @@ func TestCompile(t *testing.T) {
177178
wantOutbound: ob1,
178179
wantIP: net.ParseIP("2.2.2.2"),
179180
},
181+
{
182+
host: HostInfo{
183+
Name: "crap.v2ex.com",
184+
},
185+
proto: ProtocolTCP,
186+
port: 81,
187+
wantOutbound: 0,
188+
},
189+
{
190+
host: HostInfo{
191+
Name: "crap.v2ex.com",
192+
},
193+
proto: ProtocolUDP,
194+
port: 80,
195+
wantOutbound: ob3,
196+
},
197+
{
198+
host: HostInfo{
199+
Name: "crap.v2ex.com",
200+
},
201+
proto: ProtocolUDP,
202+
port: 81,
203+
wantOutbound: ob3,
204+
},
180205
{
181206
host: HostInfo{
182207
IPv4: net.ParseIP("210.140.92.187"),
@@ -261,9 +286,12 @@ func TestCompile(t *testing.T) {
261286
}
262287

263288
for _, test := range tests {
264-
gotOutbound, gotIP := comp.Match(test.host, test.proto, test.port)
265-
assert.Equal(t, test.wantOutbound, gotOutbound)
266-
assert.Equal(t, test.wantIP, gotIP)
289+
testName := fmt.Sprintf("%s#%s#%d", test.host, test.proto, test.port)
290+
t.Run(testName, func(t *testing.T) {
291+
gotOutbound, gotIP := comp.Match(test.host, test.proto, test.port)
292+
assert.Equal(t, test.wantOutbound, gotOutbound)
293+
assert.Equal(t, test.wantIP, gotIP)
294+
})
267295
}
268296

269297
// Test Invalid Port Range Rule

0 commit comments

Comments
 (0)