diff --git a/internal/config/config.go b/internal/config/config.go index b3eae14..812bc00 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -41,8 +41,9 @@ type LoggingConfig struct { } type DnsConfig struct { - ListenAddress string `yaml:"address" envconfig:"DNS_LISTEN_ADDRESS"` - ListenPort uint `yaml:"port" envconfig:"DNS_LISTEN_PORT"` + ListenAddress string `yaml:"address" envconfig:"DNS_LISTEN_ADDRESS"` + ListenPort uint `yaml:"port" envconfig:"DNS_LISTEN_PORT"` + FallbackServers []string `yaml:"fallbackServers" envconfig:"DNS_FALLBACK_SERVERS"` } type DebugConfig struct { @@ -78,6 +79,8 @@ var globalConfig = &Config{ Dns: DnsConfig{ ListenAddress: "", ListenPort: 8053, + // hdns.io + FallbackServers: []string{"103.196.38.38", "103.196.38.39", "103.196.38.40"}, }, Debug: DebugConfig{ ListenAddress: "localhost", diff --git a/internal/dns/dns.go b/internal/dns/dns.go index adae538..f15f7cc 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -2,6 +2,7 @@ package dns import ( "fmt" + "math/rand" "net" "strings" @@ -84,9 +85,53 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) { return } + // Query fallback servers + fallbackResp, err := queryFallbackServer(r) + if err != nil { + // Send failure response + m.SetRcode(r, dns.RcodeServerFailure) + if err := w.WriteMsg(m); err != nil { + logger.Errorf("failed to write response: %s", err) + } + logger.Errorf("failed to query fallback server: %s", err) + return + } else { + // Copy relevant data from fallback response into our response + m.SetRcode(r, fallbackResp.MsgHdr.Rcode) + m.RecursionDesired = r.RecursionDesired + m.RecursionAvailable = fallbackResp.RecursionAvailable + if fallbackResp.Ns != nil { + m.Ns = append(m.Ns, fallbackResp.Ns...) + } + if fallbackResp.Answer != nil { + m.Answer = append(m.Answer, fallbackResp.Answer...) + } + if fallbackResp.Extra != nil { + m.Extra = append(m.Extra, fallbackResp.Extra...) + } + // Send response + if err := w.WriteMsg(m); err != nil { + logger.Errorf("failed to write response: %s", err) + } + return + } + // Return NXDOMAIN if we have no information about the requested domain or any of its parents m.SetRcode(r, dns.RcodeNameError) if err := w.WriteMsg(m); err != nil { logger.Errorf("failed to write response: %s", err) } } + +func queryFallbackServer(req *dns.Msg) (*dns.Msg, error) { + // Pick random fallback server + cfg := config.GetConfig() + fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))] + // Query chosen server + m := new(dns.Msg) + m.Id = dns.Id() + m.RecursionDesired = req.RecursionDesired + m.Question = append(m.Question, req.Question...) + in, err := dns.Exchange(m, fmt.Sprintf("%s:53", fallbackServer)) + return in, err +}