Skip to content

Commit c5b2bae

Browse files
authored
Merge pull request #45 from blinklabs-io/feat/fallback-query-ns
feat: explicitly query fallback servers for NS records
2 parents 71668b8 + 6530134 commit c5b2bae

File tree

1 file changed

+116
-36
lines changed

1 file changed

+116
-36
lines changed

internal/dns/dns.go

Lines changed: 116 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -48,28 +48,18 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
4848
}
4949
}
5050

51-
// Split query name into labels and lookup each domain and parent until we get a hit
52-
queryLabels := dns.SplitDomainName(r.Question[0].Name)
53-
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
54-
lookupDomainName := strings.Join(queryLabels[startLabelIdx:], ".")
55-
nameServers, err := state.GetState().LookupDomain(lookupDomainName)
56-
if err != nil {
57-
logger.Errorf("failed to lookup domain: %s", err)
58-
}
59-
if nameServers == nil {
60-
continue
61-
}
51+
nameserverDomain, nameservers, err := findNameserversForDomain(r.Question[0].Name)
52+
if err != nil {
53+
logger.Errorf("failed to lookup nameservers for %s: %s", r.Question[0].Name, err)
54+
}
55+
if nameservers != nil {
6256
// Assemble response
6357
m.SetReply(r)
6458
if cfg.Dns.RecursionEnabled {
6559
// Pick random nameserver for domain
66-
tmpNameservers := []string{}
67-
for nameserver := range nameServers {
68-
tmpNameservers = append(tmpNameservers, nameserver)
69-
}
70-
tmpNameserver := nameServers[tmpNameservers[rand.Intn(len(tmpNameservers))]]
60+
tmpNameserver := randomNameserverAddress(nameservers)
7161
// Query the random domain nameserver we picked above
72-
resp, err := queryServer(r, tmpNameserver)
62+
resp, err := queryServer(r, tmpNameserver.String())
7363
if err != nil {
7464
// Send failure response
7565
m.SetRcode(r, dns.RcodeServerFailure)
@@ -87,31 +77,30 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
8777
return
8878
}
8979
} else {
90-
for nameserver, ipAddress := range nameServers {
91-
// Add trailing dot to make everybody happy
92-
nameserver = nameserver + `.`
80+
for nameserver, addresses := range nameservers {
9381
// NS record
9482
ns := &dns.NS{
95-
Hdr: dns.RR_Header{Name: (lookupDomainName + `.`), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 999},
83+
Hdr: dns.RR_Header{Name: (nameserverDomain), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 999},
9684
Ns: nameserver,
9785
}
9886
m.Ns = append(m.Ns, ns)
99-
// A or AAAA record
100-
ipAddr := net.ParseIP(ipAddress)
101-
if ipAddr.To4() != nil {
102-
// IPv4
103-
a := &dns.A{
104-
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 999},
105-
A: ipAddr,
106-
}
107-
m.Extra = append(m.Extra, a)
108-
} else {
109-
// IPv6
110-
aaaa := &dns.AAAA{
111-
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 999},
112-
AAAA: ipAddr,
87+
for _, address := range addresses {
88+
// A or AAAA record
89+
if address.To4() != nil {
90+
// IPv4
91+
a := &dns.A{
92+
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 999},
93+
A: address,
94+
}
95+
m.Extra = append(m.Extra, a)
96+
} else {
97+
// IPv6
98+
aaaa := &dns.AAAA{
99+
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 999},
100+
AAAA: address,
101+
}
102+
m.Extra = append(m.Extra, aaaa)
113103
}
114-
m.Extra = append(m.Extra, aaaa)
115104
}
116105
}
117106
}
@@ -178,3 +167,94 @@ func queryServer(req *dns.Msg, nameserver string) (*dns.Msg, error) {
178167
in, err := dns.Exchange(m, fmt.Sprintf("%s:53", nameserver))
179168
return in, err
180169
}
170+
171+
func randomNameserverAddress(nameservers map[string][]net.IP) net.IP {
172+
// Put all namserver addresses in single list
173+
tmpNameservers := []net.IP{}
174+
for _, addresses := range nameservers {
175+
tmpNameservers = append(tmpNameservers, addresses...)
176+
}
177+
tmpNameserver := tmpNameservers[rand.Intn(len(tmpNameservers))]
178+
return tmpNameserver
179+
}
180+
181+
func doQuery(msg *dns.Msg, address string) (*dns.Msg, error) {
182+
logger := logging.GetLogger()
183+
logger.Debugf("querying %s: %s %s", address, dns.Type(msg.Question[0].Qtype).String(), msg.Question[0].Name)
184+
resp, err := dns.Exchange(msg, address)
185+
return resp, err
186+
}
187+
188+
func findNameserversForDomain(recordName string) (string, map[string][]net.IP, error) {
189+
cfg := config.GetConfig()
190+
191+
// Split record name into labels and lookup each domain and parent until we get a hit
192+
queryLabels := dns.SplitDomainName(recordName)
193+
194+
// Check on-chain domains first
195+
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
196+
lookupDomainName := strings.Join(queryLabels[startLabelIdx:], ".")
197+
nameservers, err := state.GetState().LookupDomain(lookupDomainName)
198+
if err != nil {
199+
return "", nil, err
200+
}
201+
if nameservers != nil {
202+
ret := map[string][]net.IP{}
203+
for k, v := range nameservers {
204+
k = k + `.`
205+
ret[k] = append(ret[k], net.ParseIP(v))
206+
}
207+
return dns.Fqdn(lookupDomainName), ret, nil
208+
}
209+
}
210+
211+
// Query fallback servers, if configured
212+
if len(cfg.Dns.FallbackServers) > 0 {
213+
// Pick random fallback server
214+
fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))]
215+
serverWithPort := fmt.Sprintf("%s:53", fallbackServer)
216+
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
217+
lookupDomainName := dns.Fqdn(strings.Join(queryLabels[startLabelIdx:], "."))
218+
m := new(dns.Msg)
219+
m.SetQuestion(lookupDomainName, dns.TypeNS)
220+
m.RecursionDesired = false
221+
in, err := doQuery(m, serverWithPort)
222+
if err != nil {
223+
return "", nil, err
224+
}
225+
if in.Rcode == dns.RcodeSuccess {
226+
if len(in.Answer) > 0 {
227+
ret := map[string][]net.IP{}
228+
for _, answer := range in.Answer {
229+
switch v := answer.(type) {
230+
case *dns.NS:
231+
ns := v.Ns
232+
ret[ns] = make([]net.IP, 0)
233+
// Query for matching A/AAAA records
234+
m2 := new(dns.Msg)
235+
m2.SetQuestion(ns, dns.TypeA)
236+
m2.RecursionDesired = false
237+
in2, err := doQuery(m2, serverWithPort)
238+
if err != nil {
239+
return "", nil, err
240+
}
241+
for _, answer2 := range in2.Answer {
242+
switch v := answer2.(type) {
243+
case *dns.A:
244+
ret[ns] = append(ret[ns], v.A)
245+
case *dns.AAAA:
246+
ret[ns] = append(ret[ns], v.AAAA)
247+
}
248+
}
249+
}
250+
}
251+
if len(ret) > 0 {
252+
return lookupDomainName, ret, nil
253+
}
254+
}
255+
}
256+
}
257+
}
258+
259+
return "", nil, nil
260+
}

0 commit comments

Comments
 (0)