@@ -48,28 +48,18 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
48
48
}
49
49
}
50
50
51
- // Split query name into labels and lookup each domain and parent until we get a hit
52
- queryLabels := dns .SplitDomainName (r .Question [0 ].Name )
53
- for startLabelIdx := 0 ; startLabelIdx < len (queryLabels ); startLabelIdx ++ {
54
- lookupDomainName := strings .Join (queryLabels [startLabelIdx :], "." )
55
- nameServers , err := state .GetState ().LookupDomain (lookupDomainName )
56
- if err != nil {
57
- logger .Errorf ("failed to lookup domain: %s" , err )
58
- }
59
- if nameServers == nil {
60
- continue
61
- }
51
+ nameserverDomain , nameservers , err := findNameserversForDomain (r .Question [0 ].Name )
52
+ if err != nil {
53
+ logger .Errorf ("failed to lookup nameservers for %s: %s" , r .Question [0 ].Name , err )
54
+ }
55
+ if nameservers != nil {
62
56
// Assemble response
63
57
m .SetReply (r )
64
58
if cfg .Dns .RecursionEnabled {
65
59
// Pick random nameserver for domain
66
- tmpNameservers := []string {}
67
- for nameserver := range nameServers {
68
- tmpNameservers = append (tmpNameservers , nameserver )
69
- }
70
- tmpNameserver := nameServers [tmpNameservers [rand .Intn (len (tmpNameservers ))]]
60
+ tmpNameserver := randomNameserverAddress (nameservers )
71
61
// Query the random domain nameserver we picked above
72
- resp , err := queryServer (r , tmpNameserver )
62
+ resp , err := queryServer (r , tmpNameserver . String () )
73
63
if err != nil {
74
64
// Send failure response
75
65
m .SetRcode (r , dns .RcodeServerFailure )
@@ -87,31 +77,30 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
87
77
return
88
78
}
89
79
} else {
90
- for nameserver , ipAddress := range nameServers {
91
- // Add trailing dot to make everybody happy
92
- nameserver = nameserver + `.`
80
+ for nameserver , addresses := range nameservers {
93
81
// NS record
94
82
ns := & dns.NS {
95
- Hdr : dns.RR_Header {Name : (lookupDomainName + `.` ), Rrtype : dns .TypeNS , Class : dns .ClassINET , Ttl : 999 },
83
+ Hdr : dns.RR_Header {Name : (nameserverDomain ), Rrtype : dns .TypeNS , Class : dns .ClassINET , Ttl : 999 },
96
84
Ns : nameserver ,
97
85
}
98
86
m .Ns = append (m .Ns , ns )
99
- // A or AAAA record
100
- ipAddr := net .ParseIP (ipAddress )
101
- if ipAddr .To4 () != nil {
102
- // IPv4
103
- a := & dns.A {
104
- Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeA , Class : dns .ClassINET , Ttl : 999 },
105
- A : ipAddr ,
106
- }
107
- m .Extra = append (m .Extra , a )
108
- } else {
109
- // IPv6
110
- aaaa := & dns.AAAA {
111
- Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeAAAA , Class : dns .ClassINET , Ttl : 999 },
112
- AAAA : ipAddr ,
87
+ for _ , address := range addresses {
88
+ // A or AAAA record
89
+ if address .To4 () != nil {
90
+ // IPv4
91
+ a := & dns.A {
92
+ Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeA , Class : dns .ClassINET , Ttl : 999 },
93
+ A : address ,
94
+ }
95
+ m .Extra = append (m .Extra , a )
96
+ } else {
97
+ // IPv6
98
+ aaaa := & dns.AAAA {
99
+ Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeAAAA , Class : dns .ClassINET , Ttl : 999 },
100
+ AAAA : address ,
101
+ }
102
+ m .Extra = append (m .Extra , aaaa )
113
103
}
114
- m .Extra = append (m .Extra , aaaa )
115
104
}
116
105
}
117
106
}
@@ -178,3 +167,96 @@ func queryServer(req *dns.Msg, nameserver string) (*dns.Msg, error) {
178
167
in , err := dns .Exchange (m , fmt .Sprintf ("%s:53" , nameserver ))
179
168
return in , err
180
169
}
170
+
171
+ func randomNameserverAddress (nameservers map [string ][]net.IP ) net.IP {
172
+ // Put all namserver addresses in single list
173
+ tmpNameservers := []net.IP {}
174
+ for _ , addresses := range nameservers {
175
+ for _ , address := range addresses {
176
+ tmpNameservers = append (tmpNameservers , address )
177
+ }
178
+ }
179
+ tmpNameserver := tmpNameservers [rand .Intn (len (tmpNameservers ))]
180
+ return tmpNameserver
181
+ }
182
+
183
+ func doQuery (msg * dns.Msg , address string ) (* dns.Msg , error ) {
184
+ logger := logging .GetLogger ()
185
+ logger .Debugf ("querying %s: %s %s" , address , dns .Type (msg .Question [0 ].Qtype ).String (), msg .Question [0 ].Name )
186
+ resp , err := dns .Exchange (msg , address )
187
+ return resp , err
188
+ }
189
+
190
+ func findNameserversForDomain (recordName string ) (string , map [string ][]net.IP , error ) {
191
+ cfg := config .GetConfig ()
192
+
193
+ // Split record name into labels and lookup each domain and parent until we get a hit
194
+ queryLabels := dns .SplitDomainName (recordName )
195
+
196
+ // Check on-chain domains first
197
+ for startLabelIdx := 0 ; startLabelIdx < len (queryLabels ); startLabelIdx ++ {
198
+ lookupDomainName := strings .Join (queryLabels [startLabelIdx :], "." )
199
+ nameservers , err := state .GetState ().LookupDomain (lookupDomainName )
200
+ if err != nil {
201
+ return "" , nil , err
202
+ }
203
+ if nameservers != nil {
204
+ ret := map [string ][]net.IP {}
205
+ for k , v := range nameservers {
206
+ k = k + `.`
207
+ ret [k ] = append (ret [k ], net .ParseIP (v ))
208
+ }
209
+ return dns .Fqdn (lookupDomainName ), ret , nil
210
+ }
211
+ }
212
+
213
+ // Query fallback servers, if configured
214
+ if len (cfg .Dns .FallbackServers ) > 0 {
215
+ // Pick random fallback server
216
+ fallbackServer := cfg .Dns .FallbackServers [rand .Intn (len (cfg .Dns .FallbackServers ))]
217
+ serverWithPort := fmt .Sprintf ("%s:53" , fallbackServer )
218
+ for startLabelIdx := 0 ; startLabelIdx < len (queryLabels ); startLabelIdx ++ {
219
+ lookupDomainName := dns .Fqdn (strings .Join (queryLabels [startLabelIdx :], "." ))
220
+ m := new (dns.Msg )
221
+ m .SetQuestion (lookupDomainName , dns .TypeNS )
222
+ m .RecursionDesired = false
223
+ in , err := doQuery (m , serverWithPort )
224
+ if err != nil {
225
+ return "" , nil , err
226
+ }
227
+ if in .Rcode == dns .RcodeSuccess {
228
+ if len (in .Answer ) > 0 {
229
+ ret := map [string ][]net.IP {}
230
+ for _ , answer := range in .Answer {
231
+ switch v := answer .(type ) {
232
+ case * dns.NS :
233
+ ns := v .Ns
234
+ ret [ns ] = make ([]net.IP , 0 )
235
+ // Query for matching A/AAAA records
236
+ m2 := new (dns.Msg )
237
+ m2 .SetQuestion (ns , dns .TypeA )
238
+ m2 .RecursionDesired = false
239
+ in2 , err := doQuery (m2 , serverWithPort )
240
+ if err != nil {
241
+ return "" , nil , err
242
+ }
243
+ for _ , answer2 := range in2 .Answer {
244
+ switch v := answer2 .(type ) {
245
+ case * dns.A :
246
+ ret [ns ] = append (ret [ns ], v .A )
247
+ case * dns.AAAA :
248
+ ret [ns ] = append (ret [ns ], v .AAAA )
249
+ }
250
+ }
251
+ }
252
+ }
253
+ if len (ret ) > 0 {
254
+ return lookupDomainName , ret , nil
255
+ }
256
+ }
257
+ }
258
+ }
259
+ }
260
+
261
+ return "" , nil , nil
262
+ }
0 commit comments