diff --git a/internal/dns/dns.go b/internal/dns/dns.go index f67eeec..283efbe 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -68,41 +68,25 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) { } } - // Check for known record for domain nameserver - records, err := state.GetState().LookupNameserverRecord( + // Check for known record from local storage + records, err := state.GetState().LookupRecords( + []string{dns.Type(r.Question[0].Qtype).String()}, strings.TrimSuffix(r.Question[0].Name, "."), ) if err != nil { - logger.Errorf("failed to lookup record in state: %s", err) + logger.Errorf("failed to lookup records in state: %s", err) return } if records != nil { // Assemble response m.SetReply(r) - for k, v := range records { - k = dns.Fqdn(k) - address := net.ParseIP(v) - // A or AAAA record - if address.To4() != nil { - // IPv4 - a := &dns.A{ - Hdr: dns.RR_Header{ - Name: k, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 999, - }, - A: address, - } - m.Answer = append(m.Answer, a) - } else { - // IPv6 - aaaa := &dns.AAAA{ - Hdr: dns.RR_Header{Name: k, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 999}, - AAAA: address, - } - m.Answer = append(m.Answer, aaaa) + for _, tmpRecord := range records { + tmpRR, err := stateRecordToDnsRR(tmpRecord) + if err != nil { + logger.Errorf("failed to convert state record to dns.RR: %s", err) + return } + m.Answer = append(m.Answer, tmpRR) } // Send response if err := w.WriteMsg(m); err != nil { @@ -112,6 +96,7 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) { return } + // Check for any NS records for parent domains from local storage nameserverDomain, nameservers, err := findNameserversForDomain( r.Question[0].Name, ) @@ -182,6 +167,30 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) { return } + // Query fallback servers, if configured + if len(cfg.Dns.FallbackServers) > 0 { + // Pick random fallback server + fallbackServer := randomFallbackServer() + // Pass along query to chosen fallback server + resp, err := doQuery(r, fallbackServer, false) + if err != nil { + // Send failure response + m.SetRcode(r, dns.RcodeServerFailure) + if err := w.WriteMsg(m); err != nil { + logger.Errorf("failed to write response: %s", err) + } + logger.Errorf("failed to query domain nameserver: %s", err) + return + } else { + copyResponse(r, resp, m) + // Send response + if err := w.WriteMsg(m); err != nil { + logger.Errorf("failed to write response: %s", err) + } + return + } + } + // Return NXDOMAIN if we have no information about the requested domain or any of its parents m.SetRcode(r, dns.RcodeNameError) if err := w.WriteMsg(m); err != nil { @@ -189,6 +198,21 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) { } } +func stateRecordToDnsRR(record state.DomainRecord) (dns.RR, error) { + tmpTtl := "" + if record.Ttl > 0 { + tmpTtl = fmt.Sprintf("%d", record.Ttl) + } + tmpRR := fmt.Sprintf( + "%s %s IN %s %s", + record.Lhs, + tmpTtl, + record.Type, + record.Rhs, + ) + return dns.NewRR(tmpRR) +} + func copyResponse(req *dns.Msg, srcResp *dns.Msg, destResp *dns.Msg) { // Copy relevant data from original request and source response into destination response destResp.SetRcode(req, srcResp.MsgHdr.Rcode) @@ -279,8 +303,6 @@ func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) { func findNameserversForDomain( recordName string, ) (string, map[string][]net.IP, error) { - cfg := config.GetConfig() - // Split record name into labels and lookup each domain and parent until we get a hit queryLabels := dns.SplitDomainName(recordName) @@ -314,51 +336,6 @@ func findNameserversForDomain( } } - // Query fallback servers, if configured - if len(cfg.Dns.FallbackServers) > 0 { - // Pick random fallback server - fallbackServer := randomFallbackServer() - for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ { - lookupDomainName := dns.Fqdn( - strings.Join(queryLabels[startLabelIdx:], "."), - ) - m := createQuery(lookupDomainName, dns.TypeNS) - in, err := doQuery(m, fallbackServer, false) - if err != nil { - return "", nil, err - } - if in.Rcode == dns.RcodeSuccess { - if len(in.Answer) > 0 { - ret := map[string][]net.IP{} - for _, answer := range in.Answer { - switch v := answer.(type) { - case *dns.NS: - ns := v.Ns - ret[ns] = make([]net.IP, 0) - // Query for matching A/AAAA records - m2 := createQuery(ns, dns.TypeA) - in2, err := doQuery(m2, fallbackServer, false) - if err != nil { - return "", nil, err - } - for _, answer2 := range in2.Answer { - switch v := answer2.(type) { - case *dns.A: - ret[ns] = append(ret[ns], v.A) - case *dns.AAAA: - ret[ns] = append(ret[ns], v.AAAA) - } - } - } - } - if len(ret) > 0 { - return lookupDomainName, ret, nil - } - } - } - } - } - return "", nil, nil } diff --git a/internal/state/state.go b/internal/state/state.go index 5ec3031..3724a47 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -259,44 +259,6 @@ func (s *State) LookupRecords(recordTypes []string, recordName string) ([]Domain return ret, nil } -// LookupNameserverRecord searches the domain nameserver entries for one matching the requested record -func (s *State) LookupNameserverRecord( - recordName string, -) (map[string]string, error) { - ret := map[string]string{} - err := s.db.View(func(txn *badger.Txn) error { - opts := badger.DefaultIteratorOptions - // Makes key scans faster - opts.PrefetchValues = false - it := txn.NewIterator(opts) - defer it.Close() - for it.Rewind(); it.Valid(); it.Next() { - item := it.Item() - k := item.Key() - if strings.HasSuffix( - string(k), - fmt.Sprintf("_nameserver_%s", recordName), - ) { - err := item.Value(func(v []byte) error { - ret[recordName] = string(v) - return nil - }) - if err != nil { - return err - } - } - } - return nil - }) - if err != nil { - return nil, err - } - if len(ret) == 0 { - return nil, nil - } - return ret, nil -} - func GetState() *State { return globalState }