@@ -920,3 +920,270 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
920920 defer ws .Close ()
921921 sendRecv (t , ws )
922922}
923+
924+ // TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
925+ func TestNetDialConnect (t * testing.T ) {
926+
927+ upgrader := Upgrader {}
928+ handler := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
929+ if IsWebSocketUpgrade (r ) {
930+ c , err := upgrader .Upgrade (w , r , http.Header {"X-Test-Host" : {r .Host }})
931+ if err != nil {
932+ t .Fatal (err )
933+ }
934+ c .Close ()
935+ } else {
936+ w .Header ().Set ("X-Test-Host" , r .Host )
937+ }
938+ })
939+
940+ server := httptest .NewServer (handler )
941+ defer server .Close ()
942+
943+ tlsServer := httptest .NewTLSServer (handler )
944+ defer tlsServer .Close ()
945+
946+ testUrls := map [* httptest.Server ]string {
947+ server : "ws://" + server .Listener .Addr ().String () + "/" ,
948+ tlsServer : "wss://" + tlsServer .Listener .Addr ().String () + "/" ,
949+ }
950+
951+ cas := rootCAs (t , tlsServer )
952+ tlsConfig := & tls.Config {
953+ RootCAs : cas ,
954+ ServerName : "example.com" ,
955+ InsecureSkipVerify : false ,
956+ }
957+
958+ tests := map [string ]struct {
959+ server * httptest.Server // server to use
960+ netDial func (network , addr string ) (net.Conn , error )
961+ netDialContext func (ctx context.Context , network , addr string ) (net.Conn , error )
962+ netDialTLS func (network , addr string ) (net.Conn , error )
963+ netDialTLSContext func (ctx context.Context , network , addr string ) (net.Conn , error )
964+ tlsClientConfig * tls.Config
965+ }{
966+ "HTTP server, all NetDial* defined, shall use NetDialContext" : {
967+ server : server ,
968+ netDial : func (network , addr string ) (net.Conn , error ) {
969+ t .Error ("NetDial should not be called" )
970+ t .FailNow ()
971+ return nil , nil
972+ },
973+ netDialContext : func (_ context.Context , network , addr string ) (net.Conn , error ) {
974+ return net .Dial (network , addr )
975+ },
976+ netDialTLS : func (network , addr string ) (net.Conn , error ) {
977+ t .Error ("NetDialTLS should not be called" )
978+ t .FailNow ()
979+ return nil , nil
980+ },
981+ netDialTLSContext : func (_ context.Context , network , addr string ) (net.Conn , error ) {
982+ t .Error ("NetDialTLSContext should not be called" )
983+ t .FailNow ()
984+ return nil , nil
985+ },
986+ tlsClientConfig : nil ,
987+ },
988+ "HTTP server, all NetDial* undefined" : {
989+ server : server ,
990+ netDial : nil ,
991+ netDialContext : nil ,
992+ netDialTLS : nil ,
993+ netDialTLSContext : nil ,
994+ tlsClientConfig : nil ,
995+ },
996+ "HTTP server, NetDialContext undefined, shall fallback to NetDial" : {
997+ server : server ,
998+ netDial : func (network , addr string ) (net.Conn , error ) {
999+ return net .Dial (network , addr )
1000+ },
1001+ netDialContext : nil ,
1002+ netDialTLS : func (network , addr string ) (net.Conn , error ) {
1003+ t .Error ("NetDialTLS should not be called" )
1004+ t .FailNow ()
1005+ return nil , nil
1006+ },
1007+ netDialTLSContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1008+ t .Error ("NetDialTLSContext should not be called" )
1009+ t .FailNow ()
1010+ return nil , nil
1011+ },
1012+ tlsClientConfig : nil ,
1013+ },
1014+ "HTTPS server, all NetDial* defined, shall use NetDialTLSContext" : {
1015+ server : tlsServer ,
1016+ netDial : func (network , addr string ) (net.Conn , error ) {
1017+ t .Error ("NetDial should not be called" )
1018+ t .FailNow ()
1019+ return nil , nil
1020+ },
1021+ netDialContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1022+ t .Error ("NetDialContext should not be called" )
1023+ t .FailNow ()
1024+ return nil , nil
1025+ },
1026+ netDialTLS : func (network , addr string ) (net.Conn , error ) {
1027+ t .Error ("NetDialTLS should not be called" )
1028+ t .FailNow ()
1029+ return nil , nil
1030+ },
1031+ netDialTLSContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1032+ netConn , err := net .Dial (network , addr )
1033+ if err != nil {
1034+ return nil , err
1035+ }
1036+ tlsConn := tls .Client (netConn , tlsConfig )
1037+ err = tlsConn .Handshake ()
1038+ if err != nil {
1039+ return nil , err
1040+ }
1041+ return tlsConn , nil
1042+ },
1043+ tlsClientConfig : nil ,
1044+ },
1045+ "HTTPS server, NetDialTLSContext undefined, shall fallback to NetTLSDial" : {
1046+ server : tlsServer ,
1047+ netDial : func (network , addr string ) (net.Conn , error ) {
1048+ t .Error ("NetDial should not be called" )
1049+ t .FailNow ()
1050+ return nil , nil
1051+ },
1052+ netDialContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1053+ t .Error ("NetDialContext should not be called" )
1054+ t .FailNow ()
1055+ return nil , nil
1056+ },
1057+ netDialTLS : func (network , addr string ) (net.Conn , error ) {
1058+ netConn , err := net .Dial (network , addr )
1059+ if err != nil {
1060+ return nil , err
1061+ }
1062+ tlsConn := tls .Client (netConn , tlsConfig )
1063+ err = tlsConn .Handshake ()
1064+ if err != nil {
1065+ return nil , err
1066+ }
1067+ return tlsConn , nil
1068+ },
1069+ netDialTLSContext : nil ,
1070+ tlsClientConfig : nil ,
1071+ },
1072+ "HTTPS server, NetDialTLS* undefined, shall fallback to NetDialContext and do handshake" : {
1073+ server : tlsServer ,
1074+ netDial : func (network , addr string ) (net.Conn , error ) {
1075+ t .Error ("NetDial should not be called" )
1076+ t .FailNow ()
1077+ return nil , nil
1078+ },
1079+ netDialContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1080+ return net .Dial (network , addr )
1081+ },
1082+ netDialTLS : nil ,
1083+ netDialTLSContext : nil ,
1084+ tlsClientConfig : tlsConfig ,
1085+ },
1086+ "HTTPS server, NetDialTLS* and NetDialContext undefined, shall fallback to NetDial and do handshake" : {
1087+ server : tlsServer ,
1088+ netDial : func (network , addr string ) (net.Conn , error ) {
1089+ return net .Dial (network , addr )
1090+ },
1091+ netDialContext : nil ,
1092+ netDialTLS : nil ,
1093+ netDialTLSContext : nil ,
1094+ tlsClientConfig : tlsConfig ,
1095+ },
1096+ "HTTPS server, all NetDial* undefined" : {
1097+ server : tlsServer ,
1098+ netDial : nil ,
1099+ netDialContext : nil ,
1100+ netDialTLS : nil ,
1101+ netDialTLSContext : nil ,
1102+ tlsClientConfig : tlsConfig ,
1103+ },
1104+ "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake" : {
1105+ server : tlsServer ,
1106+ netDial : func (network , addr string ) (net.Conn , error ) {
1107+ t .Error ("NetDial should not be called" )
1108+ t .FailNow ()
1109+ return nil , nil
1110+ },
1111+ netDialContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1112+ t .Error ("NetDialContext should not be called" )
1113+ t .FailNow ()
1114+ return nil , nil
1115+ },
1116+ netDialTLS : func (network , addr string ) (net.Conn , error ) {
1117+ t .Error ("NetDialTLS should not be called" )
1118+ t .FailNow ()
1119+ return nil , nil
1120+ },
1121+ netDialTLSContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1122+ netConn , err := net .Dial (network , addr )
1123+ if err != nil {
1124+ return nil , err
1125+ }
1126+ tlsConn := tls .Client (netConn , tlsConfig )
1127+ err = tlsConn .Handshake ()
1128+ if err != nil {
1129+ return nil , err
1130+ }
1131+ return tlsConn , nil
1132+ },
1133+ tlsClientConfig : & tls.Config {
1134+ RootCAs : nil ,
1135+ ServerName : "badserver.com" ,
1136+ InsecureSkipVerify : false ,
1137+ },
1138+ },
1139+ "HTTPS server, NetDialTLS defined, dummy TlsClientConfig defined, shall not do handshake" : {
1140+ server : tlsServer ,
1141+ netDial : func (network , addr string ) (net.Conn , error ) {
1142+ t .Error ("NetDial should not be called" )
1143+ t .FailNow ()
1144+ return nil , nil
1145+ },
1146+ netDialContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1147+ t .Error ("NetDialContext should not be called" )
1148+ t .FailNow ()
1149+ return nil , nil
1150+ },
1151+ netDialTLS : func (network , addr string ) (net.Conn , error ) {
1152+ netConn , err := net .Dial (network , addr )
1153+ if err != nil {
1154+ return nil , err
1155+ }
1156+ tlsConn := tls .Client (netConn , tlsConfig )
1157+ err = tlsConn .Handshake ()
1158+ if err != nil {
1159+ return nil , err
1160+ }
1161+ return tlsConn , nil
1162+ },
1163+ netDialTLSContext : nil ,
1164+ tlsClientConfig : & tls.Config {
1165+ RootCAs : nil ,
1166+ ServerName : "badserver.com" ,
1167+ InsecureSkipVerify : false ,
1168+ },
1169+ },
1170+ }
1171+
1172+ for name , tc := range tests {
1173+ dialer := Dialer {
1174+ NetDial : tc .netDial ,
1175+ NetDialContext : tc .netDialContext ,
1176+ NetDialTLS : tc .netDialTLS ,
1177+ NetDialTLSContext : tc .netDialTLSContext ,
1178+ TLSClientConfig : tc .tlsClientConfig ,
1179+ }
1180+
1181+ // Test websocket dial
1182+ c , _ , err := dialer .Dial (testUrls [tc .server ], nil )
1183+ if err != nil {
1184+ t .Errorf ("FAILED %s, err: %s" , name , err .Error ())
1185+ } else {
1186+ c .Close ()
1187+ }
1188+ }
1189+ }
0 commit comments