Skip to content

Commit f37208f

Browse files
committed
feat: explicitly query fallback servers for NS records
Fixes #29
1 parent 71668b8 commit f37208f

File tree

1 file changed

+118
-36
lines changed

1 file changed

+118
-36
lines changed

internal/dns/dns.go

Lines changed: 118 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -48,28 +48,18 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
4848
}
4949
}
5050

51-
// Split query name into labels and lookup each domain and parent until we get a hit
52-
queryLabels := dns.SplitDomainName(r.Question[0].Name)
53-
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
54-
lookupDomainName := strings.Join(queryLabels[startLabelIdx:], ".")
55-
nameServers, err := state.GetState().LookupDomain(lookupDomainName)
56-
if err != nil {
57-
logger.Errorf("failed to lookup domain: %s", err)
58-
}
59-
if nameServers == nil {
60-
continue
61-
}
51+
nameserverDomain, nameservers, err := findNameserversForDomain(r.Question[0].Name)
52+
if err != nil {
53+
logger.Errorf("failed to lookup nameservers for %s: %s", r.Question[0].Name, err)
54+
}
55+
if nameservers != nil {
6256
// Assemble response
6357
m.SetReply(r)
6458
if cfg.Dns.RecursionEnabled {
6559
// Pick random nameserver for domain
66-
tmpNameservers := []string{}
67-
for nameserver := range nameServers {
68-
tmpNameservers = append(tmpNameservers, nameserver)
69-
}
70-
tmpNameserver := nameServers[tmpNameservers[rand.Intn(len(tmpNameservers))]]
60+
tmpNameserver := randomNameserverAddress(nameservers)
7161
// Query the random domain nameserver we picked above
72-
resp, err := queryServer(r, tmpNameserver)
62+
resp, err := queryServer(r, tmpNameserver.String())
7363
if err != nil {
7464
// Send failure response
7565
m.SetRcode(r, dns.RcodeServerFailure)
@@ -87,31 +77,30 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
8777
return
8878
}
8979
} else {
90-
for nameserver, ipAddress := range nameServers {
91-
// Add trailing dot to make everybody happy
92-
nameserver = nameserver + `.`
80+
for nameserver, addresses := range nameservers {
9381
// NS record
9482
ns := &dns.NS{
95-
Hdr: dns.RR_Header{Name: (lookupDomainName + `.`), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 999},
83+
Hdr: dns.RR_Header{Name: (nameserverDomain), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 999},
9684
Ns: nameserver,
9785
}
9886
m.Ns = append(m.Ns, ns)
99-
// A or AAAA record
100-
ipAddr := net.ParseIP(ipAddress)
101-
if ipAddr.To4() != nil {
102-
// IPv4
103-
a := &dns.A{
104-
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 999},
105-
A: ipAddr,
106-
}
107-
m.Extra = append(m.Extra, a)
108-
} else {
109-
// IPv6
110-
aaaa := &dns.AAAA{
111-
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 999},
112-
AAAA: ipAddr,
87+
for _, address := range addresses {
88+
// A or AAAA record
89+
if address.To4() != nil {
90+
// IPv4
91+
a := &dns.A{
92+
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 999},
93+
A: address,
94+
}
95+
m.Extra = append(m.Extra, a)
96+
} else {
97+
// IPv6
98+
aaaa := &dns.AAAA{
99+
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 999},
100+
AAAA: address,
101+
}
102+
m.Extra = append(m.Extra, aaaa)
113103
}
114-
m.Extra = append(m.Extra, aaaa)
115104
}
116105
}
117106
}
@@ -178,3 +167,96 @@ func queryServer(req *dns.Msg, nameserver string) (*dns.Msg, error) {
178167
in, err := dns.Exchange(m, fmt.Sprintf("%s:53", nameserver))
179168
return in, err
180169
}
170+
171+
func randomNameserverAddress(nameservers map[string][]net.IP) net.IP {
172+
// Put all namserver addresses in single list
173+
tmpNameservers := []net.IP{}
174+
for _, addresses := range nameservers {
175+
for _, address := range addresses {
176+
tmpNameservers = append(tmpNameservers, address)
177+
}
178+
}
179+
tmpNameserver := tmpNameservers[rand.Intn(len(tmpNameservers))]
180+
return tmpNameserver
181+
}
182+
183+
func doQuery(msg *dns.Msg, address string) (*dns.Msg, error) {
184+
logger := logging.GetLogger()
185+
logger.Debugf("querying %s: %s %s", address, dns.Type(msg.Question[0].Qtype).String(), msg.Question[0].Name)
186+
resp, err := dns.Exchange(msg, address)
187+
return resp, err
188+
}
189+
190+
func findNameserversForDomain(recordName string) (string, map[string][]net.IP, error) {
191+
cfg := config.GetConfig()
192+
193+
// Split record name into labels and lookup each domain and parent until we get a hit
194+
queryLabels := dns.SplitDomainName(recordName)
195+
196+
// Check on-chain domains first
197+
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
198+
lookupDomainName := strings.Join(queryLabels[startLabelIdx:], ".")
199+
nameservers, err := state.GetState().LookupDomain(lookupDomainName)
200+
if err != nil {
201+
return "", nil, err
202+
}
203+
if nameservers != nil {
204+
ret := map[string][]net.IP{}
205+
for k, v := range nameservers {
206+
k = k + `.`
207+
ret[k] = append(ret[k], net.ParseIP(v))
208+
}
209+
return dns.Fqdn(lookupDomainName), ret, nil
210+
}
211+
}
212+
213+
// Query fallback servers, if configured
214+
if len(cfg.Dns.FallbackServers) > 0 {
215+
// Pick random fallback server
216+
fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))]
217+
serverWithPort := fmt.Sprintf("%s:53", fallbackServer)
218+
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
219+
lookupDomainName := dns.Fqdn(strings.Join(queryLabels[startLabelIdx:], "."))
220+
m := new(dns.Msg)
221+
m.SetQuestion(lookupDomainName, dns.TypeNS)
222+
m.RecursionDesired = false
223+
in, err := doQuery(m, serverWithPort)
224+
if err != nil {
225+
return "", nil, err
226+
}
227+
if in.Rcode == dns.RcodeSuccess {
228+
if len(in.Answer) > 0 {
229+
ret := map[string][]net.IP{}
230+
for _, answer := range in.Answer {
231+
switch v := answer.(type) {
232+
case *dns.NS:
233+
ns := v.Ns
234+
ret[ns] = make([]net.IP, 0)
235+
// Query for matching A/AAAA records
236+
m2 := new(dns.Msg)
237+
m2.SetQuestion(ns, dns.TypeA)
238+
m2.RecursionDesired = false
239+
in2, err := doQuery(m2, serverWithPort)
240+
if err != nil {
241+
return "", nil, err
242+
}
243+
for _, answer2 := range in2.Answer {
244+
switch v := answer2.(type) {
245+
case *dns.A:
246+
ret[ns] = append(ret[ns], v.A)
247+
case *dns.AAAA:
248+
ret[ns] = append(ret[ns], v.AAAA)
249+
}
250+
}
251+
}
252+
}
253+
if len(ret) > 0 {
254+
return lookupDomainName, ret, nil
255+
}
256+
}
257+
}
258+
}
259+
}
260+
261+
return "", nil, nil
262+
}

0 commit comments

Comments
 (0)