@@ -11,8 +11,10 @@ import (
1111 "crypto/x509"
1212 "encoding/base64"
1313 "encoding/binary"
14+ "fmt"
1415 "io"
1516 "io/ioutil"
17+ "log"
1618 "net"
1719 "net/http"
1820 "net/http/cookiejar"
@@ -42,17 +44,12 @@ var cstDialer = Dialer{
4244 HandshakeTimeout : 30 * time .Second ,
4345}
4446
45- var cstDialerWithoutHandshakeTimeout = Dialer {
46- Subprotocols : []string {"p1" , "p2" },
47- ReadBufferSize : 1024 ,
48- WriteBufferSize : 1024 ,
49- }
50-
5147type cstHandler struct { * testing.T }
5248
5349type cstServer struct {
5450 * httptest.Server
5551 URL string
52+ t * testing.T
5653}
5754
5855const (
@@ -288,10 +285,7 @@ func TestDialCookieJar(t *testing.T) {
288285 sendRecv (t , ws )
289286}
290287
291- func TestDialTLS (t * testing.T ) {
292- s := newTLSServer (t )
293- defer s .Close ()
294-
288+ func rootCAs (t * testing.T , s * httptest.Server ) * x509.CertPool {
295289 certs := x509 .NewCertPool ()
296290 for _ , c := range s .TLS .Certificates {
297291 roots , err := x509 .ParseCertificates (c .Certificate [len (c .Certificate )- 1 ])
@@ -302,35 +296,15 @@ func TestDialTLS(t *testing.T) {
302296 certs .AddCert (root )
303297 }
304298 }
305-
306- d := cstDialer
307- d .TLSClientConfig = & tls.Config {RootCAs : certs }
308- ws , _ , err := d .Dial (s .URL , nil )
309- if err != nil {
310- t .Fatalf ("Dial: %v" , err )
311- }
312- defer ws .Close ()
313- sendRecv (t , ws )
314- }
315-
316- func xTestDialTLSBadCert (t * testing.T ) {
317- // This test is deactivated because of noisy logging from the net/http package.
318- s := newTLSServer (t )
319- defer s .Close ()
320-
321- ws , _ , err := cstDialer .Dial (s .URL , nil )
322- if err == nil {
323- ws .Close ()
324- t .Fatalf ("Dial: nil" )
325- }
299+ return certs
326300}
327301
328- func TestDialTLSNoVerify (t * testing.T ) {
302+ func TestDialTLS (t * testing.T ) {
329303 s := newTLSServer (t )
330304 defer s .Close ()
331305
332306 d := cstDialer
333- d .TLSClientConfig = & tls.Config {InsecureSkipVerify : true }
307+ d .TLSClientConfig = & tls.Config {RootCAs : rootCAs ( t , s . Server ) }
334308 ws , _ , err := d .Dial (s .URL , nil )
335309 if err != nil {
336310 t .Fatalf ("Dial: %v" , err )
@@ -415,7 +389,8 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
415389 s := newServer (t )
416390 defer s .Close ()
417391
418- d := cstDialerWithoutHandshakeTimeout
392+ d := cstDialer
393+ d .HandshakeTimeout = 0
419394 d .NetDialContext = func (ctx context.Context , n , a string ) (net.Conn , error ) {
420395 netDialer := & net.Dialer {}
421396 c , err := netDialer .DialContext (ctx , n , a )
@@ -566,33 +541,195 @@ func TestRespOnBadHandshake(t *testing.T) {
566541 }
567542}
568543
569- // TestHostHeader confirms that the host header provided in the call to Dial is
570- // sent to the server.
571- func TestHostHeader (t * testing.T ) {
572- s := newServer (t )
573- defer s .Close ()
544+ type testLogWriter struct {
545+ t * testing.T
546+ }
574547
575- specifiedHost := make (chan string , 1 )
576- origHandler := s .Server .Config .Handler
548+ func (w testLogWriter ) Write (p []byte ) (int , error ) {
549+ w .t .Logf ("%s" , p )
550+ return len (p ), nil
551+ }
577552
578- // Capture the request Host header.
579- s .Server .Config .Handler = http .HandlerFunc (
580- func (w http.ResponseWriter , r * http.Request ) {
581- specifiedHost <- r .Host
582- origHandler .ServeHTTP (w , r )
583- })
553+ // TestHost tests handling of host names and confirms that it matches net/http.
554+ func TestHost (t * testing.T ) {
584555
585- ws , _ , err := cstDialer .Dial (s .URL , http.Header {"Host" : {"testhost" }})
586- if err != nil {
587- t .Fatalf ("Dial: %v" , err )
588- }
589- defer ws .Close ()
556+ upgrader := Upgrader {}
557+ handler := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
558+ if IsWebSocketUpgrade (r ) {
559+ c , err := upgrader .Upgrade (w , r , http.Header {"X-Test-Host" : {r .Host }})
560+ if err != nil {
561+ t .Fatal (err )
562+ }
563+ c .Close ()
564+ } else {
565+ w .Header ().Set ("X-Test-Host" , r .Host )
566+ }
567+ })
568+
569+ server := httptest .NewServer (handler )
570+ defer server .Close ()
571+
572+ tlsServer := httptest .NewTLSServer (handler )
573+ defer tlsServer .Close ()
574+
575+ addrs := map [* httptest.Server ]string {server : server .Listener .Addr ().String (), tlsServer : tlsServer .Listener .Addr ().String ()}
576+ wsProtos := map [* httptest.Server ]string {server : "ws://" , tlsServer : "wss://" }
577+ httpProtos := map [* httptest.Server ]string {server : "http://" , tlsServer : "https://" }
578+
579+ // Avoid log noise from net/http server by logging to testing.T
580+ server .Config .ErrorLog = log .New (testLogWriter {t }, "" , 0 )
581+ tlsServer .Config .ErrorLog = server .Config .ErrorLog
582+
583+ cas := rootCAs (t , tlsServer )
584+
585+ tests := []struct {
586+ fail bool // true if dial / get should fail
587+ server * httptest.Server // server to use
588+ url string // host for request URI
589+ header string // optional request host header
590+ tls string // optiona host for tls ServerName
591+ wantAddr string // expected host for dial
592+ wantHeader string // expected request header on server
593+ insecureSkipVerify bool
594+ }{
595+ {
596+ server : server ,
597+ url : addrs [server ],
598+ wantAddr : addrs [server ],
599+ wantHeader : addrs [server ],
600+ },
601+ {
602+ server : tlsServer ,
603+ url : addrs [tlsServer ],
604+ wantAddr : addrs [tlsServer ],
605+ wantHeader : addrs [tlsServer ],
606+ },
607+
608+ {
609+ server : server ,
610+ url : addrs [server ],
611+ header : "badhost.com" ,
612+ wantAddr : addrs [server ],
613+ wantHeader : "badhost.com" ,
614+ },
615+ {
616+ server : tlsServer ,
617+ url : addrs [tlsServer ],
618+ header : "badhost.com" ,
619+ wantAddr : addrs [tlsServer ],
620+ wantHeader : "badhost.com" ,
621+ },
622+
623+ {
624+ server : server ,
625+ url : "example.com" ,
626+ header : "badhost.com" ,
627+ wantAddr : "example.com:80" ,
628+ wantHeader : "badhost.com" ,
629+ },
630+ {
631+ server : tlsServer ,
632+ url : "example.com" ,
633+ header : "badhost.com" ,
634+ wantAddr : "example.com:443" ,
635+ wantHeader : "badhost.com" ,
636+ },
590637
591- if gotHost := <- specifiedHost ; gotHost != "testhost" {
592- t .Fatalf ("gotHost = %q, want \" testhost\" " , gotHost )
638+ {
639+ server : server ,
640+ url : "badhost.com" ,
641+ header : "example.com" ,
642+ wantAddr : "badhost.com:80" ,
643+ wantHeader : "example.com" ,
644+ },
645+ {
646+ fail : true ,
647+ server : tlsServer ,
648+ url : "badhost.com" ,
649+ header : "example.com" ,
650+ wantAddr : "badhost.com:443" ,
651+ },
652+ {
653+ server : tlsServer ,
654+ url : "badhost.com" ,
655+ insecureSkipVerify : true ,
656+ wantAddr : "badhost.com:443" ,
657+ wantHeader : "badhost.com" ,
658+ },
659+ {
660+ server : tlsServer ,
661+ url : "badhost.com" ,
662+ tls : "example.com" ,
663+ wantAddr : "badhost.com:443" ,
664+ wantHeader : "badhost.com" ,
665+ },
593666 }
594667
595- sendRecv (t , ws )
668+ for i , tt := range tests {
669+
670+ tls := & tls.Config {
671+ RootCAs : cas ,
672+ ServerName : tt .tls ,
673+ InsecureSkipVerify : tt .insecureSkipVerify ,
674+ }
675+
676+ var gotAddr string
677+ dialer := Dialer {
678+ NetDial : func (network , addr string ) (net.Conn , error ) {
679+ gotAddr = addr
680+ return net .Dial (network , addrs [tt .server ])
681+ },
682+ TLSClientConfig : tls ,
683+ }
684+
685+ // Test websocket dial
686+
687+ h := http.Header {}
688+ if tt .header != "" {
689+ h .Set ("Host" , tt .header )
690+ }
691+ c , resp , err := dialer .Dial (wsProtos [tt .server ]+ tt .url + "/" , h )
692+ if err == nil {
693+ c .Close ()
694+ }
695+
696+ check := func (protos map [* httptest.Server ]string ) {
697+ name := fmt .Sprintf ("%d: %s%s/ header[Host]=%q, tls.ServerName=%q" , i + 1 , protos [tt .server ], tt .url , tt .header , tt .tls )
698+ if gotAddr != tt .wantAddr {
699+ t .Errorf ("%s: got addr %s, want %s" , name , gotAddr , tt .wantAddr )
700+ }
701+ switch {
702+ case tt .fail && err == nil :
703+ t .Errorf ("%s: unexpected success" , name )
704+ case ! tt .fail && err != nil :
705+ t .Errorf ("%s: unexpected error %v" , name , err )
706+ case ! tt .fail && err == nil :
707+ if gotHost := resp .Header .Get ("X-Test-Host" ); gotHost != tt .wantHeader {
708+ t .Errorf ("%s: got host %s, want %s" , name , gotHost , tt .wantHeader )
709+ }
710+ }
711+ }
712+
713+ check (wsProtos )
714+
715+ // Confirm that net/http has same result
716+
717+ transport := & http.Transport {
718+ Dial : dialer .NetDial ,
719+ TLSClientConfig : dialer .TLSClientConfig ,
720+ }
721+ req , _ := http .NewRequest ("GET" , httpProtos [tt .server ]+ tt .url + "/" , nil )
722+ if tt .header != "" {
723+ req .Host = tt .header
724+ }
725+ client := & http.Client {Transport : transport }
726+ resp , err = client .Do (req )
727+ if err == nil {
728+ resp .Body .Close ()
729+ }
730+ transport .CloseIdleConnections ()
731+ check (httpProtos )
732+ }
596733}
597734
598735func TestDialCompression (t * testing.T ) {
@@ -716,19 +853,8 @@ func TestTracingDialWithContext(t *testing.T) {
716853 s := newTLSServer (t )
717854 defer s .Close ()
718855
719- certs := x509 .NewCertPool ()
720- for _ , c := range s .TLS .Certificates {
721- roots , err := x509 .ParseCertificates (c .Certificate [len (c .Certificate )- 1 ])
722- if err != nil {
723- t .Fatalf ("error parsing server's root cert: %v" , err )
724- }
725- for _ , root := range roots {
726- certs .AddCert (root )
727- }
728- }
729-
730856 d := cstDialer
731- d .TLSClientConfig = & tls.Config {RootCAs : certs }
857+ d .TLSClientConfig = & tls.Config {RootCAs : rootCAs ( t , s . Server ) }
732858
733859 ws , _ , err := d .DialContext (ctx , s .URL , nil )
734860 if err != nil {
@@ -766,19 +892,8 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
766892 s := newTLSServer (t )
767893 defer s .Close ()
768894
769- certs := x509 .NewCertPool ()
770- for _ , c := range s .TLS .Certificates {
771- roots , err := x509 .ParseCertificates (c .Certificate [len (c .Certificate )- 1 ])
772- if err != nil {
773- t .Fatalf ("error parsing server's root cert: %v" , err )
774- }
775- for _ , root := range roots {
776- certs .AddCert (root )
777- }
778- }
779-
780895 d := cstDialer
781- d .TLSClientConfig = & tls.Config {RootCAs : certs }
896+ d .TLSClientConfig = & tls.Config {RootCAs : rootCAs ( t , s . Server ) }
782897
783898 ws , _ , err := d .DialContext (ctx , s .URL , nil )
784899 if err != nil {
0 commit comments