Skip to content

Commit 01bec93

Browse files
authored
Merge pull request #3 from libp2p/feat/resolver-cache
Add DNS RR cache
2 parents 73c190e + 9b30936 commit 01bec93

File tree

3 files changed

+189
-17
lines changed

3 files changed

+189
-17
lines changed

request.go

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,74 +62,88 @@ func doRequest(ctx context.Context, url string, m *dns.Msg) (*dns.Msg, error) {
6262
return r, nil
6363
}
6464

65-
func doRequestA(ctx context.Context, url string, domain string) ([]net.IPAddr, error) {
65+
func doRequestA(ctx context.Context, url string, domain string) ([]net.IPAddr, uint32, error) {
6666
fqdn := dns.Fqdn(domain)
6767

6868
m := new(dns.Msg)
6969
m.SetQuestion(fqdn, dns.TypeA)
7070

7171
r, err := doRequest(ctx, url, m)
7272
if err != nil {
73-
return nil, err
73+
return nil, 0, err
7474
}
7575

76+
var ttl uint32
7677
result := make([]net.IPAddr, 0, len(r.Answer))
7778
for _, rr := range r.Answer {
7879
switch v := rr.(type) {
7980
case *dns.A:
8081
result = append(result, net.IPAddr{IP: v.A})
82+
if ttl == 0 || v.Hdr.Ttl < ttl {
83+
ttl = v.Hdr.Ttl
84+
}
8185
default:
8286
log.Warnf("unexpected DNS resource record %+v", rr)
8387
}
8488
}
8589

86-
return result, nil
90+
return result, ttl, nil
8791
}
8892

89-
func doRequestAAAA(ctx context.Context, url string, domain string) ([]net.IPAddr, error) {
93+
func doRequestAAAA(ctx context.Context, url string, domain string) ([]net.IPAddr, uint32, error) {
9094
fqdn := dns.Fqdn(domain)
9195

9296
m := new(dns.Msg)
9397
m.SetQuestion(fqdn, dns.TypeAAAA)
9498

9599
r, err := doRequest(ctx, url, m)
96100
if err != nil {
97-
return nil, err
101+
return nil, 0, err
98102
}
99103

104+
var ttl uint32
100105
result := make([]net.IPAddr, 0, len(r.Answer))
101106
for _, rr := range r.Answer {
102107
switch v := rr.(type) {
103108
case *dns.AAAA:
104109
result = append(result, net.IPAddr{IP: v.AAAA})
110+
if ttl == 0 || v.Hdr.Ttl < ttl {
111+
ttl = v.Hdr.Ttl
112+
}
113+
105114
default:
106115
log.Warnf("unexpected DNS resource record %+v", rr)
107116
}
108117
}
109118

110-
return result, nil
119+
return result, ttl, nil
111120
}
112121

113-
func doRequestTXT(ctx context.Context, url string, domain string) ([]string, error) {
122+
func doRequestTXT(ctx context.Context, url string, domain string) ([]string, uint32, error) {
114123
fqdn := dns.Fqdn(domain)
115124

116125
m := new(dns.Msg)
117126
m.SetQuestion(fqdn, dns.TypeTXT)
118127

119128
r, err := doRequest(ctx, url, m)
120129
if err != nil {
121-
return nil, err
130+
return nil, 0, err
122131
}
123132

133+
var ttl uint32
124134
var result []string
125135
for _, rr := range r.Answer {
126136
switch v := rr.(type) {
127137
case *dns.TXT:
128138
result = append(result, v.Txt...)
139+
if ttl == 0 || v.Hdr.Ttl < ttl {
140+
ttl = v.Hdr.Ttl
141+
}
142+
129143
default:
130144
log.Warnf("unexpected DNS resource record %+v", rr)
131145
}
132146
}
133147

134-
return result, nil
148+
return result, ttl, nil
135149
}

resolver.go

Lines changed: 111 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,53 +4,158 @@ import (
44
"context"
55
"net"
66
"strings"
7+
"sync"
8+
"time"
9+
10+
"github.com/miekg/dns"
711

812
madns "github.com/multiformats/go-multiaddr-dns"
913
)
1014

1115
type Resolver struct {
16+
sync.RWMutex
1217
url string
18+
19+
// RR cache
20+
ipCache map[string]ipAddrEntry
21+
txtCache map[string]txtEntry
22+
}
23+
24+
type ipAddrEntry struct {
25+
ips []net.IPAddr
26+
expire time.Time
27+
}
28+
29+
type txtEntry struct {
30+
txt []string
31+
expire time.Time
1332
}
1433

1534
func NewResolver(url string) *Resolver {
1635
if !strings.HasPrefix(url, "https://") {
1736
url = "https://" + url
1837
}
1938

20-
return &Resolver{url: url}
39+
return &Resolver{
40+
url: url,
41+
ipCache: make(map[string]ipAddrEntry),
42+
txtCache: make(map[string]txtEntry),
43+
}
2144
}
2245

2346
var _ madns.BasicResolver = (*Resolver)(nil)
2447

2548
func (r *Resolver) LookupIPAddr(ctx context.Context, domain string) (result []net.IPAddr, err error) {
49+
result, ok := r.getCachedIPAddr(domain)
50+
if ok {
51+
return result, nil
52+
}
53+
2654
type response struct {
2755
ips []net.IPAddr
56+
ttl uint32
2857
err error
2958
}
3059

3160
resch := make(chan response, 2)
3261
go func() {
33-
ip4, err := doRequestA(ctx, r.url, domain)
34-
resch <- response{ip4, err}
62+
ip4, ttl, err := doRequestA(ctx, r.url, domain)
63+
resch <- response{ip4, ttl, err}
3564
}()
3665

3766
go func() {
38-
ip6, err := doRequestAAAA(ctx, r.url, domain)
39-
resch <- response{ip6, err}
67+
ip6, ttl, err := doRequestAAAA(ctx, r.url, domain)
68+
resch <- response{ip6, ttl, err}
4069
}()
4170

71+
var ttl uint32
4272
for i := 0; i < 2; i++ {
4373
r := <-resch
4474
if r.err != nil {
4575
return nil, r.err
4676
}
4777

4878
result = append(result, r.ips...)
79+
if ttl == 0 || r.ttl < ttl {
80+
ttl = r.ttl
81+
}
4982
}
5083

84+
r.cacheIPAddr(domain, result, ttl)
5185
return result, nil
5286
}
5387

5488
func (r *Resolver) LookupTXT(ctx context.Context, domain string) ([]string, error) {
55-
return doRequestTXT(ctx, r.url, domain)
89+
result, ok := r.getCachedTXT(domain)
90+
if ok {
91+
return result, nil
92+
}
93+
94+
result, ttl, err := doRequestTXT(ctx, r.url, domain)
95+
if err != nil {
96+
return nil, err
97+
}
98+
99+
r.cacheTXT(domain, result, ttl)
100+
return result, nil
101+
}
102+
103+
func (r *Resolver) getCachedIPAddr(domain string) ([]net.IPAddr, bool) {
104+
r.RLock()
105+
defer r.RUnlock()
106+
107+
fqdn := dns.Fqdn(domain)
108+
entry, ok := r.ipCache[fqdn]
109+
if !ok {
110+
return nil, false
111+
}
112+
113+
if time.Now().After(entry.expire) {
114+
delete(r.ipCache, fqdn)
115+
return nil, false
116+
}
117+
118+
return entry.ips, true
119+
}
120+
121+
func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl uint32) {
122+
if ttl == 0 {
123+
return
124+
}
125+
126+
r.Lock()
127+
defer r.Unlock()
128+
129+
fqdn := dns.Fqdn(domain)
130+
r.ipCache[fqdn] = ipAddrEntry{ips, time.Now().Add(time.Duration(ttl) * time.Second)}
131+
}
132+
133+
func (r *Resolver) getCachedTXT(domain string) ([]string, bool) {
134+
r.RLock()
135+
defer r.RUnlock()
136+
137+
fqdn := dns.Fqdn(domain)
138+
entry, ok := r.txtCache[fqdn]
139+
if !ok {
140+
return nil, false
141+
}
142+
143+
if time.Now().After(entry.expire) {
144+
delete(r.txtCache, fqdn)
145+
return nil, false
146+
}
147+
148+
return entry.txt, true
149+
}
150+
151+
func (r *Resolver) cacheTXT(domain string, txt []string, ttl uint32) {
152+
if ttl == 0 {
153+
return
154+
}
155+
156+
r.Lock()
157+
defer r.Unlock()
158+
159+
fqdn := dns.Fqdn(domain)
160+
r.txtCache[fqdn] = txtEntry{txt, time.Now().Add(time.Duration(ttl) * time.Second)}
56161
}

resolver_test.go

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
11
package doh
22

33
import (
4+
"bytes"
45
"context"
6+
"net"
57
"testing"
68
)
79

810
func TestLookupIPAddr(t *testing.T) {
911
r := NewResolver("https://cloudflare-dns.com/dns-query")
1012

11-
ips, err := r.LookupIPAddr(context.Background(), "libp2p.io")
13+
domain := "libp2p.io"
14+
ips, err := r.LookupIPAddr(context.Background(), domain)
1215
if err != nil {
1316
t.Fatal(err)
1417
}
1518
if len(ips) == 0 {
1619
t.Fatal("got no IPs")
1720
}
21+
22+
// check that we got both IPv4 and IPv6 addrs
1823
var got4, got6 bool
1924
for _, ip := range ips {
2025
if len(ip.IP.To4()) == 4 {
@@ -29,16 +34,64 @@ func TestLookupIPAddr(t *testing.T) {
2934
if !got6 {
3035
t.Fatal("got no IPv6 addresses")
3136
}
37+
38+
// check the cache
39+
ips2, ok := r.getCachedIPAddr(domain)
40+
if !ok {
41+
t.Fatal("expected cache to be populated")
42+
}
43+
if !sameIPs(ips, ips2) {
44+
t.Fatal("expected cache to contain the same addrs")
45+
}
3246
}
3347

3448
func TestLookupTXT(t *testing.T) {
3549
r := NewResolver("https://cloudflare-dns.com/dns-query")
3650

37-
txt, err := r.LookupTXT(context.Background(), "_dnsaddr.bootstrap.libp2p.io")
51+
domain := "_dnsaddr.bootstrap.libp2p.io"
52+
txt, err := r.LookupTXT(context.Background(), domain)
3853
if err != nil {
3954
t.Fatal(err)
4055
}
4156
if len(txt) == 0 {
4257
t.Fatal("got no TXT entries")
4358
}
59+
60+
// check the cache
61+
txt2, ok := r.getCachedTXT(domain)
62+
if !ok {
63+
t.Fatal("expected cache to be populated")
64+
}
65+
if !sameTXT(txt, txt2) {
66+
t.Fatal("expected cache to contain the same txt entries")
67+
}
68+
69+
}
70+
71+
func sameIPs(a, b []net.IPAddr) bool {
72+
if len(a) != len(b) {
73+
return false
74+
}
75+
76+
for i := range a {
77+
if !bytes.Equal(a[i].IP, b[i].IP) {
78+
return false
79+
}
80+
}
81+
82+
return true
83+
}
84+
85+
func sameTXT(a, b []string) bool {
86+
if len(a) != len(b) {
87+
return false
88+
}
89+
90+
for i := range a {
91+
if a[i] != b[i] {
92+
return false
93+
}
94+
}
95+
96+
return true
4497
}

0 commit comments

Comments
 (0)