Skip to content

feat: explicitly query fallback servers for NS records #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 5, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 116 additions & 36 deletions internal/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,28 +48,18 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
}
}

// Split query name into labels and lookup each domain and parent until we get a hit
queryLabels := dns.SplitDomainName(r.Question[0].Name)
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
lookupDomainName := strings.Join(queryLabels[startLabelIdx:], ".")
nameServers, err := state.GetState().LookupDomain(lookupDomainName)
if err != nil {
logger.Errorf("failed to lookup domain: %s", err)
}
if nameServers == nil {
continue
}
nameserverDomain, nameservers, err := findNameserversForDomain(r.Question[0].Name)
if err != nil {
logger.Errorf("failed to lookup nameservers for %s: %s", r.Question[0].Name, err)
}
if nameservers != nil {
// Assemble response
m.SetReply(r)
if cfg.Dns.RecursionEnabled {
// Pick random nameserver for domain
tmpNameservers := []string{}
for nameserver := range nameServers {
tmpNameservers = append(tmpNameservers, nameserver)
}
tmpNameserver := nameServers[tmpNameservers[rand.Intn(len(tmpNameservers))]]
tmpNameserver := randomNameserverAddress(nameservers)
// Query the random domain nameserver we picked above
resp, err := queryServer(r, tmpNameserver)
resp, err := queryServer(r, tmpNameserver.String())
if err != nil {
// Send failure response
m.SetRcode(r, dns.RcodeServerFailure)
Expand All @@ -87,31 +77,30 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
return
}
} else {
for nameserver, ipAddress := range nameServers {
// Add trailing dot to make everybody happy
nameserver = nameserver + `.`
for nameserver, addresses := range nameservers {
// NS record
ns := &dns.NS{
Hdr: dns.RR_Header{Name: (lookupDomainName + `.`), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 999},
Hdr: dns.RR_Header{Name: (nameserverDomain), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 999},
Ns: nameserver,
}
m.Ns = append(m.Ns, ns)
// A or AAAA record
ipAddr := net.ParseIP(ipAddress)
if ipAddr.To4() != nil {
// IPv4
a := &dns.A{
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 999},
A: ipAddr,
}
m.Extra = append(m.Extra, a)
} else {
// IPv6
aaaa := &dns.AAAA{
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 999},
AAAA: ipAddr,
for _, address := range addresses {
// A or AAAA record
if address.To4() != nil {
// IPv4
a := &dns.A{
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 999},
A: address,
}
m.Extra = append(m.Extra, a)
} else {
// IPv6
aaaa := &dns.AAAA{
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 999},
AAAA: address,
}
m.Extra = append(m.Extra, aaaa)
}
m.Extra = append(m.Extra, aaaa)
}
}
}
Expand Down Expand Up @@ -178,3 +167,94 @@ func queryServer(req *dns.Msg, nameserver string) (*dns.Msg, error) {
in, err := dns.Exchange(m, fmt.Sprintf("%s:53", nameserver))
return in, err
}

func randomNameserverAddress(nameservers map[string][]net.IP) net.IP {
// Put all namserver addresses in single list
tmpNameservers := []net.IP{}
for _, addresses := range nameservers {
tmpNameservers = append(tmpNameservers, addresses...)
}
tmpNameserver := tmpNameservers[rand.Intn(len(tmpNameservers))]
return tmpNameserver
}

func doQuery(msg *dns.Msg, address string) (*dns.Msg, error) {
logger := logging.GetLogger()
logger.Debugf("querying %s: %s %s", address, dns.Type(msg.Question[0].Qtype).String(), msg.Question[0].Name)
resp, err := dns.Exchange(msg, address)
return resp, err
}

func findNameserversForDomain(recordName string) (string, map[string][]net.IP, error) {
cfg := config.GetConfig()

// Split record name into labels and lookup each domain and parent until we get a hit
queryLabels := dns.SplitDomainName(recordName)

// Check on-chain domains first
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
lookupDomainName := strings.Join(queryLabels[startLabelIdx:], ".")
nameservers, err := state.GetState().LookupDomain(lookupDomainName)
if err != nil {
return "", nil, err
}
if nameservers != nil {
ret := map[string][]net.IP{}
for k, v := range nameservers {
k = k + `.`
ret[k] = append(ret[k], net.ParseIP(v))
}
return dns.Fqdn(lookupDomainName), ret, nil
}
}

// Query fallback servers, if configured
if len(cfg.Dns.FallbackServers) > 0 {
// Pick random fallback server
fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))]
serverWithPort := fmt.Sprintf("%s:53", fallbackServer)
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
lookupDomainName := dns.Fqdn(strings.Join(queryLabels[startLabelIdx:], "."))
m := new(dns.Msg)
m.SetQuestion(lookupDomainName, dns.TypeNS)
m.RecursionDesired = false
in, err := doQuery(m, serverWithPort)
if err != nil {
return "", nil, err
}
if in.Rcode == dns.RcodeSuccess {
if len(in.Answer) > 0 {
ret := map[string][]net.IP{}
for _, answer := range in.Answer {
switch v := answer.(type) {
case *dns.NS:
ns := v.Ns
ret[ns] = make([]net.IP, 0)
// Query for matching A/AAAA records
m2 := new(dns.Msg)
m2.SetQuestion(ns, dns.TypeA)
m2.RecursionDesired = false
in2, err := doQuery(m2, serverWithPort)
if err != nil {
return "", nil, err
}
for _, answer2 := range in2.Answer {
switch v := answer2.(type) {
case *dns.A:
ret[ns] = append(ret[ns], v.A)
case *dns.AAAA:
ret[ns] = append(ret[ns], v.AAAA)
}
}
}
}
if len(ret) > 0 {
return lookupDomainName, ret, nil
}
}
}
}
}

return "", nil, nil
}