@@ -48,28 +48,18 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
48
48
}
49
49
}
50
50
51
- // Split query name into labels and lookup each domain and parent until we get a hit
52
- queryLabels := dns .SplitDomainName (r .Question [0 ].Name )
53
- for startLabelIdx := 0 ; startLabelIdx < len (queryLabels ); startLabelIdx ++ {
54
- lookupDomainName := strings .Join (queryLabels [startLabelIdx :], "." )
55
- nameServers , err := state .GetState ().LookupDomain (lookupDomainName )
56
- if err != nil {
57
- logger .Errorf ("failed to lookup domain: %s" , err )
58
- }
59
- if nameServers == nil {
60
- continue
61
- }
51
+ nameserverDomain , nameservers , err := findNameserversForDomain (r .Question [0 ].Name )
52
+ if err != nil {
53
+ logger .Errorf ("failed to lookup nameservers for %s: %s" , r .Question [0 ].Name , err )
54
+ }
55
+ if nameservers != nil {
62
56
// Assemble response
63
57
m .SetReply (r )
64
58
if cfg .Dns .RecursionEnabled {
65
59
// Pick random nameserver for domain
66
- tmpNameservers := []string {}
67
- for nameserver := range nameServers {
68
- tmpNameservers = append (tmpNameservers , nameserver )
69
- }
70
- tmpNameserver := nameServers [tmpNameservers [rand .Intn (len (tmpNameservers ))]]
60
+ tmpNameserver := randomNameserverAddress (nameservers )
71
61
// Query the random domain nameserver we picked above
72
- resp , err := queryServer (r , tmpNameserver )
62
+ resp , err := queryServer (r , tmpNameserver . String () )
73
63
if err != nil {
74
64
// Send failure response
75
65
m .SetRcode (r , dns .RcodeServerFailure )
@@ -87,31 +77,30 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
87
77
return
88
78
}
89
79
} else {
90
- for nameserver , ipAddress := range nameServers {
91
- // Add trailing dot to make everybody happy
92
- nameserver = nameserver + `.`
80
+ for nameserver , addresses := range nameservers {
93
81
// NS record
94
82
ns := & dns.NS {
95
- Hdr : dns.RR_Header {Name : (lookupDomainName + `.` ), Rrtype : dns .TypeNS , Class : dns .ClassINET , Ttl : 999 },
83
+ Hdr : dns.RR_Header {Name : (nameserverDomain ), Rrtype : dns .TypeNS , Class : dns .ClassINET , Ttl : 999 },
96
84
Ns : nameserver ,
97
85
}
98
86
m .Ns = append (m .Ns , ns )
99
- // A or AAAA record
100
- ipAddr := net .ParseIP (ipAddress )
101
- if ipAddr .To4 () != nil {
102
- // IPv4
103
- a := & dns.A {
104
- Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeA , Class : dns .ClassINET , Ttl : 999 },
105
- A : ipAddr ,
106
- }
107
- m .Extra = append (m .Extra , a )
108
- } else {
109
- // IPv6
110
- aaaa := & dns.AAAA {
111
- Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeAAAA , Class : dns .ClassINET , Ttl : 999 },
112
- AAAA : ipAddr ,
87
+ for _ , address := range addresses {
88
+ // A or AAAA record
89
+ if address .To4 () != nil {
90
+ // IPv4
91
+ a := & dns.A {
92
+ Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeA , Class : dns .ClassINET , Ttl : 999 },
93
+ A : address ,
94
+ }
95
+ m .Extra = append (m .Extra , a )
96
+ } else {
97
+ // IPv6
98
+ aaaa := & dns.AAAA {
99
+ Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeAAAA , Class : dns .ClassINET , Ttl : 999 },
100
+ AAAA : address ,
101
+ }
102
+ m .Extra = append (m .Extra , aaaa )
113
103
}
114
- m .Extra = append (m .Extra , aaaa )
115
104
}
116
105
}
117
106
}
@@ -178,3 +167,94 @@ func queryServer(req *dns.Msg, nameserver string) (*dns.Msg, error) {
178
167
in , err := dns .Exchange (m , fmt .Sprintf ("%s:53" , nameserver ))
179
168
return in , err
180
169
}
170
+
171
+ func randomNameserverAddress (nameservers map [string ][]net.IP ) net.IP {
172
+ // Put all namserver addresses in single list
173
+ tmpNameservers := []net.IP {}
174
+ for _ , addresses := range nameservers {
175
+ tmpNameservers = append (tmpNameservers , addresses ... )
176
+ }
177
+ tmpNameserver := tmpNameservers [rand .Intn (len (tmpNameservers ))]
178
+ return tmpNameserver
179
+ }
180
+
181
+ func doQuery (msg * dns.Msg , address string ) (* dns.Msg , error ) {
182
+ logger := logging .GetLogger ()
183
+ logger .Debugf ("querying %s: %s %s" , address , dns .Type (msg .Question [0 ].Qtype ).String (), msg .Question [0 ].Name )
184
+ resp , err := dns .Exchange (msg , address )
185
+ return resp , err
186
+ }
187
+
188
+ func findNameserversForDomain (recordName string ) (string , map [string ][]net.IP , error ) {
189
+ cfg := config .GetConfig ()
190
+
191
+ // Split record name into labels and lookup each domain and parent until we get a hit
192
+ queryLabels := dns .SplitDomainName (recordName )
193
+
194
+ // Check on-chain domains first
195
+ for startLabelIdx := 0 ; startLabelIdx < len (queryLabels ); startLabelIdx ++ {
196
+ lookupDomainName := strings .Join (queryLabels [startLabelIdx :], "." )
197
+ nameservers , err := state .GetState ().LookupDomain (lookupDomainName )
198
+ if err != nil {
199
+ return "" , nil , err
200
+ }
201
+ if nameservers != nil {
202
+ ret := map [string ][]net.IP {}
203
+ for k , v := range nameservers {
204
+ k = k + `.`
205
+ ret [k ] = append (ret [k ], net .ParseIP (v ))
206
+ }
207
+ return dns .Fqdn (lookupDomainName ), ret , nil
208
+ }
209
+ }
210
+
211
+ // Query fallback servers, if configured
212
+ if len (cfg .Dns .FallbackServers ) > 0 {
213
+ // Pick random fallback server
214
+ fallbackServer := cfg .Dns .FallbackServers [rand .Intn (len (cfg .Dns .FallbackServers ))]
215
+ serverWithPort := fmt .Sprintf ("%s:53" , fallbackServer )
216
+ for startLabelIdx := 0 ; startLabelIdx < len (queryLabels ); startLabelIdx ++ {
217
+ lookupDomainName := dns .Fqdn (strings .Join (queryLabels [startLabelIdx :], "." ))
218
+ m := new (dns.Msg )
219
+ m .SetQuestion (lookupDomainName , dns .TypeNS )
220
+ m .RecursionDesired = false
221
+ in , err := doQuery (m , serverWithPort )
222
+ if err != nil {
223
+ return "" , nil , err
224
+ }
225
+ if in .Rcode == dns .RcodeSuccess {
226
+ if len (in .Answer ) > 0 {
227
+ ret := map [string ][]net.IP {}
228
+ for _ , answer := range in .Answer {
229
+ switch v := answer .(type ) {
230
+ case * dns.NS :
231
+ ns := v .Ns
232
+ ret [ns ] = make ([]net.IP , 0 )
233
+ // Query for matching A/AAAA records
234
+ m2 := new (dns.Msg )
235
+ m2 .SetQuestion (ns , dns .TypeA )
236
+ m2 .RecursionDesired = false
237
+ in2 , err := doQuery (m2 , serverWithPort )
238
+ if err != nil {
239
+ return "" , nil , err
240
+ }
241
+ for _ , answer2 := range in2 .Answer {
242
+ switch v := answer2 .(type ) {
243
+ case * dns.A :
244
+ ret [ns ] = append (ret [ns ], v .A )
245
+ case * dns.AAAA :
246
+ ret [ns ] = append (ret [ns ], v .AAAA )
247
+ }
248
+ }
249
+ }
250
+ }
251
+ if len (ret ) > 0 {
252
+ return lookupDomainName , ret , nil
253
+ }
254
+ }
255
+ }
256
+ }
257
+ }
258
+
259
+ return "" , nil , nil
260
+ }
0 commit comments