Skip to content

Commit 0ba1620

Browse files
authored
Merge pull request #256 from projectdiscovery/proxy-fallback-tcp
if udp fails, fallback to tcp connection
2 parents 0d8c0ee + 5513869 commit 0ba1620

File tree

4 files changed

+43
-7
lines changed

4 files changed

+43
-7
lines changed

client.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,20 @@ func NewWithOptions(options Options) (*Client, error) {
8585
doh.WithProxy(options.Proxy), // no-op if empty
8686
)
8787

88+
// If proxy is specified, force TCP for all resolvers
89+
if options.Proxy != "" {
90+
for i, resolver := range parsedBaseResolvers {
91+
if networkResolver, ok := resolver.(*NetworkResolver); ok && networkResolver.Protocol == UDP {
92+
// Convert UDP resolvers to TCP when proxy is specified
93+
parsedBaseResolvers[i] = &NetworkResolver{
94+
Protocol: TCP,
95+
Host: networkResolver.Host,
96+
Port: networkResolver.Port,
97+
}
98+
}
99+
}
100+
}
101+
88102
udpDialer := &net.Dialer{LocalAddr: options.GetLocalAddr(UDP)}
89103
tcpDialer := &net.Dialer{LocalAddr: options.GetLocalAddr(TCP)}
90104
dotDialer := &net.Dialer{LocalAddr: options.GetLocalAddr(TCP)}
@@ -426,7 +440,17 @@ func (c *Client) queryMultiple(host string, requestTypes []uint16, resolver Reso
426440
} else {
427441
switch r.Protocol {
428442
case TCP:
429-
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
443+
if c.tcpProxy != nil {
444+
var tcpConn *dns.Conn
445+
tcpConn, err = c.dialWithProxy(c.tcpProxy, "tcp", resolver.String())
446+
if err != nil {
447+
break
448+
}
449+
defer tcpConn.Close()
450+
resp, _, err = c.tcpClient.ExchangeWithConn(msg, tcpConn)
451+
} else {
452+
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
453+
}
430454
case UDP:
431455
if c.options.ConnectionPoolThreads > 1 {
432456
if udpConnPool, ok := c.udpConnPool.Get(resolver.String()); ok {

client_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func TestConsistentResolve(t *testing.T) {
5555

5656
var last string
5757
for i := 0; i < 10; i++ {
58-
d, err := client.Resolve("example.com")
58+
d, err := client.Resolve("scanme.sh")
5959
require.Nil(t, err, "could not resolve dns")
6060

6161
if last != "" {
@@ -69,7 +69,7 @@ func TestConsistentResolve(t *testing.T) {
6969
func TestUDP(t *testing.T) {
7070
client, _ := New([]string{"1.1.1.1:53", "udp:8.8.8.8"}, 5)
7171

72-
d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA})
72+
d, err := client.QueryMultiple("scanme.sh", []uint16{dns.TypeA})
7373
require.Nil(t, err)
7474

7575
// From current dig result
@@ -79,7 +79,7 @@ func TestUDP(t *testing.T) {
7979
func TestTCP(t *testing.T) {
8080
client, _ := New([]string{"tcp:1.1.1.1:53", "tcp:8.8.8.8"}, 5)
8181

82-
d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA})
82+
d, err := client.QueryMultiple("scanme.sh", []uint16{dns.TypeA})
8383
require.Nil(t, err)
8484

8585
// From current dig result
@@ -89,7 +89,7 @@ func TestTCP(t *testing.T) {
8989
func TestDOH(t *testing.T) {
9090
client, _ := New([]string{"doh:https://doh.opendns.com/dns-query:post", "doh:https://doh.opendns.com/dns-query:get"}, 5)
9191

92-
d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA})
92+
d, err := client.QueryMultiple("scanme.sh", []uint16{dns.TypeA})
9393
require.Nil(t, err)
9494

9595
// From current dig result
@@ -99,7 +99,7 @@ func TestDOH(t *testing.T) {
9999
func TestDOT(t *testing.T) {
100100
client, _ := New([]string{"dot:dns.google:853", "dot:1dot1dot1dot1.cloudflare-dns.com"}, 5)
101101

102-
d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA})
102+
d, err := client.QueryMultiple("scanme.sh", []uint16{dns.TypeA})
103103
require.Nil(t, err)
104104

105105
// From current dig result

doh/doh_client_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ func TestConsistentResolve(t *testing.T) {
1111
client := New()
1212
var lastAnswer string
1313
for i := 0; i < 10; i++ {
14-
d, err := client.Query("example.com", A)
14+
d, err := client.Query("scanme.sh", A)
1515
require.Nil(t, err, "could not resolve dns")
1616
if lastAnswer == "" {
1717
lastAnswer = d.Answer[0].Data

options.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,18 @@ import (
1010
var (
1111
ErrMaxRetriesZero = errors.New("retries must be at least 1")
1212
ErrResolversEmpty = errors.New("resolvers list must not be empty")
13+
14+
BaseResolvers = []string{
15+
"1.1.1.1:53",
16+
"1.0.0.1:53",
17+
"8.8.8.8:53",
18+
"8.8.4.4:53",
19+
}
20+
DefaultOptions = Options{
21+
BaseResolvers: BaseResolvers,
22+
MaxRetries: 1,
23+
Timeout: 3 * time.Second,
24+
}
1325
)
1426

1527
type Options struct {

0 commit comments

Comments
 (0)