diff --git a/internal/dns/dns.go b/internal/dns/dns.go index 6a8b77a..3798f43 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -48,28 +48,18 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) { } } - // Split query name into labels and lookup each domain and parent until we get a hit - queryLabels := dns.SplitDomainName(r.Question[0].Name) - for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ { - lookupDomainName := strings.Join(queryLabels[startLabelIdx:], ".") - nameServers, err := state.GetState().LookupDomain(lookupDomainName) - if err != nil { - logger.Errorf("failed to lookup domain: %s", err) - } - if nameServers == nil { - continue - } + nameserverDomain, nameservers, err := findNameserversForDomain(r.Question[0].Name) + if err != nil { + logger.Errorf("failed to lookup nameservers for %s: %s", r.Question[0].Name, err) + } + if nameservers != nil { // Assemble response m.SetReply(r) if cfg.Dns.RecursionEnabled { // Pick random nameserver for domain - tmpNameservers := []string{} - for nameserver := range nameServers { - tmpNameservers = append(tmpNameservers, nameserver) - } - tmpNameserver := nameServers[tmpNameservers[rand.Intn(len(tmpNameservers))]] + tmpNameserver := randomNameserverAddress(nameservers) // Query the random domain nameserver we picked above - resp, err := queryServer(r, tmpNameserver) + resp, err := queryServer(r, tmpNameserver.String()) if err != nil { // Send failure response m.SetRcode(r, dns.RcodeServerFailure) @@ -87,31 +77,30 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) { return } } else { - for nameserver, ipAddress := range nameServers { - // Add trailing dot to make everybody happy - nameserver = nameserver + `.` + for nameserver, addresses := range nameservers { // NS record ns := &dns.NS{ - Hdr: dns.RR_Header{Name: (lookupDomainName + `.`), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 999}, + Hdr: dns.RR_Header{Name: (nameserverDomain), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 999}, Ns: nameserver, } m.Ns = append(m.Ns, ns) - // A or AAAA record - ipAddr := net.ParseIP(ipAddress) - if ipAddr.To4() != nil { - // IPv4 - a := &dns.A{ - Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 999}, - A: ipAddr, - } - m.Extra = append(m.Extra, a) - } else { - // IPv6 - aaaa := &dns.AAAA{ - Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 999}, - AAAA: ipAddr, + for _, address := range addresses { + // A or AAAA record + if address.To4() != nil { + // IPv4 + a := &dns.A{ + Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 999}, + A: address, + } + m.Extra = append(m.Extra, a) + } else { + // IPv6 + aaaa := &dns.AAAA{ + Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 999}, + AAAA: address, + } + m.Extra = append(m.Extra, aaaa) } - m.Extra = append(m.Extra, aaaa) } } } @@ -178,3 +167,94 @@ func queryServer(req *dns.Msg, nameserver string) (*dns.Msg, error) { in, err := dns.Exchange(m, fmt.Sprintf("%s:53", nameserver)) return in, err } + +func randomNameserverAddress(nameservers map[string][]net.IP) net.IP { + // Put all namserver addresses in single list + tmpNameservers := []net.IP{} + for _, addresses := range nameservers { + tmpNameservers = append(tmpNameservers, addresses...) + } + tmpNameserver := tmpNameservers[rand.Intn(len(tmpNameservers))] + return tmpNameserver +} + +func doQuery(msg *dns.Msg, address string) (*dns.Msg, error) { + logger := logging.GetLogger() + logger.Debugf("querying %s: %s %s", address, dns.Type(msg.Question[0].Qtype).String(), msg.Question[0].Name) + resp, err := dns.Exchange(msg, address) + return resp, err +} + +func findNameserversForDomain(recordName string) (string, map[string][]net.IP, error) { + cfg := config.GetConfig() + + // Split record name into labels and lookup each domain and parent until we get a hit + queryLabels := dns.SplitDomainName(recordName) + + // Check on-chain domains first + for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ { + lookupDomainName := strings.Join(queryLabels[startLabelIdx:], ".") + nameservers, err := state.GetState().LookupDomain(lookupDomainName) + if err != nil { + return "", nil, err + } + if nameservers != nil { + ret := map[string][]net.IP{} + for k, v := range nameservers { + k = k + `.` + ret[k] = append(ret[k], net.ParseIP(v)) + } + return dns.Fqdn(lookupDomainName), ret, nil + } + } + + // Query fallback servers, if configured + if len(cfg.Dns.FallbackServers) > 0 { + // Pick random fallback server + fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))] + serverWithPort := fmt.Sprintf("%s:53", fallbackServer) + for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ { + lookupDomainName := dns.Fqdn(strings.Join(queryLabels[startLabelIdx:], ".")) + m := new(dns.Msg) + m.SetQuestion(lookupDomainName, dns.TypeNS) + m.RecursionDesired = false + in, err := doQuery(m, serverWithPort) + if err != nil { + return "", nil, err + } + if in.Rcode == dns.RcodeSuccess { + if len(in.Answer) > 0 { + ret := map[string][]net.IP{} + for _, answer := range in.Answer { + switch v := answer.(type) { + case *dns.NS: + ns := v.Ns + ret[ns] = make([]net.IP, 0) + // Query for matching A/AAAA records + m2 := new(dns.Msg) + m2.SetQuestion(ns, dns.TypeA) + m2.RecursionDesired = false + in2, err := doQuery(m2, serverWithPort) + if err != nil { + return "", nil, err + } + for _, answer2 := range in2.Answer { + switch v := answer2.(type) { + case *dns.A: + ret[ns] = append(ret[ns], v.A) + case *dns.AAAA: + ret[ns] = append(ret[ns], v.AAAA) + } + } + } + } + if len(ret) > 0 { + return lookupDomainName, ret, nil + } + } + } + } + } + + return "", nil, nil +}