Skip to content

Commit f50b198

Browse files
committed
Dialer: add optional methods NetDialTLS and NetDialTLSContext
Fixes issue: #745 With the previous interface, NetDial and NetDialContext were used for both TLS and non-TLS TCP connections, and afterwards TLSClientConfig was used to do the TLS handshake. While this API works for most cases, it prevents from using more advance authentication methods during the TLS handshake, as this is out of the control of the user. This commits introduces another pair of dial methods, NetDialTLS and NetDialTLSContext which are used when dialing for TLS/TCP. The code then assumes that the handshake is done there and TLSClientConfig is not used. This API change is fully backwards compatible and it better aligns with net/http.Transport API, which has these four dial flavors. See: https://pkg.go.dev/net/http#Transport Signed-off-by: Lluis Campos <[email protected]>
1 parent e8629af commit f50b198

File tree

2 files changed

+313
-8
lines changed

2 files changed

+313
-8
lines changed

client.go

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,21 @@ type Dialer struct {
5454
NetDial func(network, addr string) (net.Conn, error)
5555

5656
// NetDialContext specifies the dial function for creating TCP connections. If
57-
// NetDialContext is nil, net.DialContext is used.
57+
// NetDialContext is nil, NetDial is used.
5858
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
5959

60+
// NetDialTLS specifies the dial function for creating TLS/TCP connections. If
61+
// NetDialTLS is nil, net.Dial is used.
62+
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
63+
// is done there and TLSClientConfig is ignored.
64+
NetDialTLS func(network, addr string) (net.Conn, error)
65+
66+
// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
67+
// NetDialTLSContext is nil, NetDialTLS is used.
68+
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
69+
// is done there and TLSClientConfig is ignored.
70+
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
71+
6072
// Proxy specifies a function to return a proxy for a given
6173
// Request. If the function returns a non-nil error, the
6274
// request is aborted with the provided error.
@@ -65,6 +77,8 @@ type Dialer struct {
6577

6678
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
6779
// If nil, the default configuration is used.
80+
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
81+
// is done there and TLSClientConfig is ignored.
6882
TLSClientConfig *tls.Config
6983

7084
// HandshakeTimeout specifies the duration for the handshake to complete.
@@ -237,13 +251,34 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
237251
// Get network dial function.
238252
var netDial func(network, add string) (net.Conn, error)
239253

240-
if d.NetDialContext != nil {
241-
netDial = func(network, addr string) (net.Conn, error) {
242-
return d.NetDialContext(ctx, network, addr)
254+
switch u.Scheme {
255+
case "http":
256+
if d.NetDialContext != nil {
257+
netDial = func(network, addr string) (net.Conn, error) {
258+
return d.NetDialContext(ctx, network, addr)
259+
}
260+
} else if d.NetDial != nil {
261+
netDial = d.NetDial
243262
}
244-
} else if d.NetDial != nil {
245-
netDial = d.NetDial
246-
} else {
263+
case "https":
264+
if d.NetDialTLSContext != nil {
265+
netDial = func(network, addr string) (net.Conn, error) {
266+
return d.NetDialTLSContext(ctx, network, addr)
267+
}
268+
} else if d.NetDialTLS != nil {
269+
netDial = d.NetDialTLS
270+
} else if d.NetDialContext != nil {
271+
netDial = func(network, addr string) (net.Conn, error) {
272+
return d.NetDialContext(ctx, network, addr)
273+
}
274+
} else if d.NetDial != nil {
275+
netDial = d.NetDial
276+
}
277+
default:
278+
return nil, nil, errMalformedURL
279+
}
280+
281+
if netDial == nil {
247282
netDialer := &net.Dialer{}
248283
netDial = func(network, addr string) (net.Conn, error) {
249284
return netDialer.DialContext(ctx, network, addr)
@@ -304,7 +339,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
304339
}
305340
}()
306341

307-
if u.Scheme == "https" {
342+
if u.Scheme == "https" && d.NetDialTLSContext == nil && d.NetDialTLS == nil {
343+
// If either NetDialTLS or NetDialTLSContext are set, assume that
344+
// the TLS handshake has already been done
345+
308346
cfg := cloneTLSConfig(d.TLSClientConfig)
309347
if cfg.ServerName == "" {
310348
cfg.ServerName = hostNoPort

client_server_test.go

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,3 +920,270 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
920920
defer ws.Close()
921921
sendRecv(t, ws)
922922
}
923+
924+
// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
925+
func TestNetDialConnect(t *testing.T) {
926+
927+
upgrader := Upgrader{}
928+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
929+
if IsWebSocketUpgrade(r) {
930+
c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
931+
if err != nil {
932+
t.Fatal(err)
933+
}
934+
c.Close()
935+
} else {
936+
w.Header().Set("X-Test-Host", r.Host)
937+
}
938+
})
939+
940+
server := httptest.NewServer(handler)
941+
defer server.Close()
942+
943+
tlsServer := httptest.NewTLSServer(handler)
944+
defer tlsServer.Close()
945+
946+
testUrls := map[*httptest.Server]string{
947+
server: "ws://" + server.Listener.Addr().String() + "/",
948+
tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/",
949+
}
950+
951+
cas := rootCAs(t, tlsServer)
952+
tlsConfig := &tls.Config{
953+
RootCAs: cas,
954+
ServerName: "example.com",
955+
InsecureSkipVerify: false,
956+
}
957+
958+
tests := map[string]struct {
959+
server *httptest.Server // server to use
960+
netDial func(network, addr string) (net.Conn, error)
961+
netDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
962+
netDialTLS func(network, addr string) (net.Conn, error)
963+
netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
964+
tlsClientConfig *tls.Config
965+
}{
966+
"HTTP server, all NetDial* defined, shall use NetDialContext": {
967+
server: server,
968+
netDial: func(network, addr string) (net.Conn, error) {
969+
t.Error("NetDial should not be called")
970+
t.FailNow()
971+
return nil, nil
972+
},
973+
netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
974+
return net.Dial(network, addr)
975+
},
976+
netDialTLS: func(network, addr string) (net.Conn, error) {
977+
t.Error("NetDialTLS should not be called")
978+
t.FailNow()
979+
return nil, nil
980+
},
981+
netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) {
982+
t.Error("NetDialTLSContext should not be called")
983+
t.FailNow()
984+
return nil, nil
985+
},
986+
tlsClientConfig: nil,
987+
},
988+
"HTTP server, all NetDial* undefined": {
989+
server: server,
990+
netDial: nil,
991+
netDialContext: nil,
992+
netDialTLS: nil,
993+
netDialTLSContext: nil,
994+
tlsClientConfig: nil,
995+
},
996+
"HTTP server, NetDialContext undefined, shall fallback to NetDial": {
997+
server: server,
998+
netDial: func(network, addr string) (net.Conn, error) {
999+
return net.Dial(network, addr)
1000+
},
1001+
netDialContext: nil,
1002+
netDialTLS: func(network, addr string) (net.Conn, error) {
1003+
t.Error("NetDialTLS should not be called")
1004+
t.FailNow()
1005+
return nil, nil
1006+
},
1007+
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1008+
t.Error("NetDialTLSContext should not be called")
1009+
t.FailNow()
1010+
return nil, nil
1011+
},
1012+
tlsClientConfig: nil,
1013+
},
1014+
"HTTPS server, all NetDial* defined, shall use NetDialTLSContext": {
1015+
server: tlsServer,
1016+
netDial: func(network, addr string) (net.Conn, error) {
1017+
t.Error("NetDial should not be called")
1018+
t.FailNow()
1019+
return nil, nil
1020+
},
1021+
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1022+
t.Error("NetDialContext should not be called")
1023+
t.FailNow()
1024+
return nil, nil
1025+
},
1026+
netDialTLS: func(network, addr string) (net.Conn, error) {
1027+
t.Error("NetDialTLS should not be called")
1028+
t.FailNow()
1029+
return nil, nil
1030+
},
1031+
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1032+
netConn, err := net.Dial(network, addr)
1033+
if err != nil {
1034+
return nil, err
1035+
}
1036+
tlsConn := tls.Client(netConn, tlsConfig)
1037+
err = tlsConn.Handshake()
1038+
if err != nil {
1039+
return nil, err
1040+
}
1041+
return tlsConn, nil
1042+
},
1043+
tlsClientConfig: nil,
1044+
},
1045+
"HTTPS server, NetDialTLSContext undefined, shall fallback to NetTLSDial": {
1046+
server: tlsServer,
1047+
netDial: func(network, addr string) (net.Conn, error) {
1048+
t.Error("NetDial should not be called")
1049+
t.FailNow()
1050+
return nil, nil
1051+
},
1052+
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1053+
t.Error("NetDialContext should not be called")
1054+
t.FailNow()
1055+
return nil, nil
1056+
},
1057+
netDialTLS: func(network, addr string) (net.Conn, error) {
1058+
netConn, err := net.Dial(network, addr)
1059+
if err != nil {
1060+
return nil, err
1061+
}
1062+
tlsConn := tls.Client(netConn, tlsConfig)
1063+
err = tlsConn.Handshake()
1064+
if err != nil {
1065+
return nil, err
1066+
}
1067+
return tlsConn, nil
1068+
},
1069+
netDialTLSContext: nil,
1070+
tlsClientConfig: nil,
1071+
},
1072+
"HTTPS server, NetDialTLS* undefined, shall fallback to NetDialContext and do handshake": {
1073+
server: tlsServer,
1074+
netDial: func(network, addr string) (net.Conn, error) {
1075+
t.Error("NetDial should not be called")
1076+
t.FailNow()
1077+
return nil, nil
1078+
},
1079+
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1080+
return net.Dial(network, addr)
1081+
},
1082+
netDialTLS: nil,
1083+
netDialTLSContext: nil,
1084+
tlsClientConfig: tlsConfig,
1085+
},
1086+
"HTTPS server, NetDialTLS* and NetDialContext undefined, shall fallback to NetDial and do handshake": {
1087+
server: tlsServer,
1088+
netDial: func(network, addr string) (net.Conn, error) {
1089+
return net.Dial(network, addr)
1090+
},
1091+
netDialContext: nil,
1092+
netDialTLS: nil,
1093+
netDialTLSContext: nil,
1094+
tlsClientConfig: tlsConfig,
1095+
},
1096+
"HTTPS server, all NetDial* undefined": {
1097+
server: tlsServer,
1098+
netDial: nil,
1099+
netDialContext: nil,
1100+
netDialTLS: nil,
1101+
netDialTLSContext: nil,
1102+
tlsClientConfig: tlsConfig,
1103+
},
1104+
"HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake": {
1105+
server: tlsServer,
1106+
netDial: func(network, addr string) (net.Conn, error) {
1107+
t.Error("NetDial should not be called")
1108+
t.FailNow()
1109+
return nil, nil
1110+
},
1111+
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1112+
t.Error("NetDialContext should not be called")
1113+
t.FailNow()
1114+
return nil, nil
1115+
},
1116+
netDialTLS: func(network, addr string) (net.Conn, error) {
1117+
t.Error("NetDialTLS should not be called")
1118+
t.FailNow()
1119+
return nil, nil
1120+
},
1121+
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1122+
netConn, err := net.Dial(network, addr)
1123+
if err != nil {
1124+
return nil, err
1125+
}
1126+
tlsConn := tls.Client(netConn, tlsConfig)
1127+
err = tlsConn.Handshake()
1128+
if err != nil {
1129+
return nil, err
1130+
}
1131+
return tlsConn, nil
1132+
},
1133+
tlsClientConfig: &tls.Config{
1134+
RootCAs: nil,
1135+
ServerName: "badserver.com",
1136+
InsecureSkipVerify: false,
1137+
},
1138+
},
1139+
"HTTPS server, NetDialTLS defined, dummy TlsClientConfig defined, shall not do handshake": {
1140+
server: tlsServer,
1141+
netDial: func(network, addr string) (net.Conn, error) {
1142+
t.Error("NetDial should not be called")
1143+
t.FailNow()
1144+
return nil, nil
1145+
},
1146+
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1147+
t.Error("NetDialContext should not be called")
1148+
t.FailNow()
1149+
return nil, nil
1150+
},
1151+
netDialTLS: func(network, addr string) (net.Conn, error) {
1152+
netConn, err := net.Dial(network, addr)
1153+
if err != nil {
1154+
return nil, err
1155+
}
1156+
tlsConn := tls.Client(netConn, tlsConfig)
1157+
err = tlsConn.Handshake()
1158+
if err != nil {
1159+
return nil, err
1160+
}
1161+
return tlsConn, nil
1162+
},
1163+
netDialTLSContext: nil,
1164+
tlsClientConfig: &tls.Config{
1165+
RootCAs: nil,
1166+
ServerName: "badserver.com",
1167+
InsecureSkipVerify: false,
1168+
},
1169+
},
1170+
}
1171+
1172+
for name, tc := range tests {
1173+
dialer := Dialer{
1174+
NetDial: tc.netDial,
1175+
NetDialContext: tc.netDialContext,
1176+
NetDialTLS: tc.netDialTLS,
1177+
NetDialTLSContext: tc.netDialTLSContext,
1178+
TLSClientConfig: tc.tlsClientConfig,
1179+
}
1180+
1181+
// Test websocket dial
1182+
c, _, err := dialer.Dial(testUrls[tc.server], nil)
1183+
if err != nil {
1184+
t.Errorf("FAILED %s, err: %s", name, err.Error())
1185+
} else {
1186+
c.Close()
1187+
}
1188+
}
1189+
}

0 commit comments

Comments
 (0)