@@ -11,6 +11,7 @@ import (
1111 "crypto/x509"
1212 "encoding/base64"
1313 "encoding/binary"
14+ "errors"
1415 "fmt"
1516 "io"
1617 "io/ioutil"
@@ -920,3 +921,180 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
920921 defer ws .Close ()
921922 sendRecv (t , ws )
922923}
924+
925+ // TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
926+ func TestNetDialConnect (t * testing.T ) {
927+
928+ upgrader := Upgrader {}
929+ handler := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
930+ if IsWebSocketUpgrade (r ) {
931+ c , err := upgrader .Upgrade (w , r , http.Header {"X-Test-Host" : {r .Host }})
932+ if err != nil {
933+ t .Fatal (err )
934+ }
935+ c .Close ()
936+ } else {
937+ w .Header ().Set ("X-Test-Host" , r .Host )
938+ }
939+ })
940+
941+ server := httptest .NewServer (handler )
942+ defer server .Close ()
943+
944+ tlsServer := httptest .NewTLSServer (handler )
945+ defer tlsServer .Close ()
946+
947+ testUrls := map [* httptest.Server ]string {
948+ server : "ws://" + server .Listener .Addr ().String () + "/" ,
949+ tlsServer : "wss://" + tlsServer .Listener .Addr ().String () + "/" ,
950+ }
951+
952+ cas := rootCAs (t , tlsServer )
953+ tlsConfig := & tls.Config {
954+ RootCAs : cas ,
955+ ServerName : "example.com" ,
956+ InsecureSkipVerify : false ,
957+ }
958+
959+ tests := []struct {
960+ name string
961+ server * httptest.Server // server to use
962+ netDial func (network , addr string ) (net.Conn , error )
963+ netDialContext func (ctx context.Context , network , addr string ) (net.Conn , error )
964+ netDialTLSContext func (ctx context.Context , network , addr string ) (net.Conn , error )
965+ tlsClientConfig * tls.Config
966+ }{
967+
968+ {
969+ name : "HTTP server, all NetDial* defined, shall use NetDialContext" ,
970+ server : server ,
971+ netDial : func (network , addr string ) (net.Conn , error ) {
972+ return nil , errors .New ("NetDial should not be called" )
973+ },
974+ netDialContext : func (_ context.Context , network , addr string ) (net.Conn , error ) {
975+ return net .Dial (network , addr )
976+ },
977+ netDialTLSContext : func (_ context.Context , network , addr string ) (net.Conn , error ) {
978+ return nil , errors .New ("NetDialTLSContext should not be called" )
979+ },
980+ tlsClientConfig : nil ,
981+ },
982+ {
983+ name : "HTTP server, all NetDial* undefined" ,
984+ server : server ,
985+ netDial : nil ,
986+ netDialContext : nil ,
987+ netDialTLSContext : nil ,
988+ tlsClientConfig : nil ,
989+ },
990+ {
991+ name : "HTTP server, NetDialContext undefined, shall fallback to NetDial" ,
992+ server : server ,
993+ netDial : func (network , addr string ) (net.Conn , error ) {
994+ return net .Dial (network , addr )
995+ },
996+ netDialContext : nil ,
997+ netDialTLSContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
998+ return nil , errors .New ("NetDialTLSContext should not be called" )
999+ },
1000+ tlsClientConfig : nil ,
1001+ },
1002+ {
1003+ name : "HTTPS server, all NetDial* defined, shall use NetDialTLSContext" ,
1004+ server : tlsServer ,
1005+ netDial : func (network , addr string ) (net.Conn , error ) {
1006+ return nil , errors .New ("NetDial should not be called" )
1007+ },
1008+ netDialContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1009+ return nil , errors .New ("NetDialContext should not be called" )
1010+ },
1011+ netDialTLSContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1012+ netConn , err := net .Dial (network , addr )
1013+ if err != nil {
1014+ return nil , err
1015+ }
1016+ tlsConn := tls .Client (netConn , tlsConfig )
1017+ err = tlsConn .Handshake ()
1018+ if err != nil {
1019+ return nil , err
1020+ }
1021+ return tlsConn , nil
1022+ },
1023+ tlsClientConfig : nil ,
1024+ },
1025+ {
1026+ name : "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake" ,
1027+ server : tlsServer ,
1028+ netDial : func (network , addr string ) (net.Conn , error ) {
1029+ return nil , errors .New ("NetDial should not be called" )
1030+ },
1031+ netDialContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1032+ return net .Dial (network , addr )
1033+ },
1034+ netDialTLSContext : nil ,
1035+ tlsClientConfig : tlsConfig ,
1036+ },
1037+ {
1038+ name : "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake" ,
1039+ server : tlsServer ,
1040+ netDial : func (network , addr string ) (net.Conn , error ) {
1041+ return net .Dial (network , addr )
1042+ },
1043+ netDialContext : nil ,
1044+ netDialTLSContext : nil ,
1045+ tlsClientConfig : tlsConfig ,
1046+ },
1047+ {
1048+ name : "HTTPS server, all NetDial* undefined" ,
1049+ server : tlsServer ,
1050+ netDial : nil ,
1051+ netDialContext : nil ,
1052+ netDialTLSContext : nil ,
1053+ tlsClientConfig : tlsConfig ,
1054+ },
1055+ {
1056+ name : "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake" ,
1057+ server : tlsServer ,
1058+ netDial : func (network , addr string ) (net.Conn , error ) {
1059+ return nil , errors .New ("NetDial should not be called" )
1060+ },
1061+ netDialContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1062+ return nil , errors .New ("NetDialContext should not be called" )
1063+ },
1064+ netDialTLSContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1065+ netConn , err := net .Dial (network , addr )
1066+ if err != nil {
1067+ return nil , err
1068+ }
1069+ tlsConn := tls .Client (netConn , tlsConfig )
1070+ err = tlsConn .Handshake ()
1071+ if err != nil {
1072+ return nil , err
1073+ }
1074+ return tlsConn , nil
1075+ },
1076+ tlsClientConfig : & tls.Config {
1077+ RootCAs : nil ,
1078+ ServerName : "badserver.com" ,
1079+ InsecureSkipVerify : false ,
1080+ },
1081+ },
1082+ }
1083+
1084+ for _ , tc := range tests {
1085+ dialer := Dialer {
1086+ NetDial : tc .netDial ,
1087+ NetDialContext : tc .netDialContext ,
1088+ NetDialTLSContext : tc .netDialTLSContext ,
1089+ TLSClientConfig : tc .tlsClientConfig ,
1090+ }
1091+
1092+ // Test websocket dial
1093+ c , _ , err := dialer .Dial (testUrls [tc .server ], nil )
1094+ if err != nil {
1095+ t .Errorf ("FAILED %s, err: %s" , tc .name , err .Error ())
1096+ } else {
1097+ c .Close ()
1098+ }
1099+ }
1100+ }
0 commit comments