@@ -6,12 +6,14 @@ package websocket
66
77import (
88 "bytes"
9+ "context"
910 "crypto/tls"
1011 "errors"
1112 "io"
1213 "io/ioutil"
1314 "net"
1415 "net/http"
16+ "net/http/httptrace"
1517 "net/url"
1618 "strings"
1719 "time"
@@ -51,6 +53,10 @@ type Dialer struct {
5153 // NetDial is nil, net.Dial is used.
5254 NetDial func (network , addr string ) (net.Conn , error )
5355
56+ // NetDialContext specifies the dial function for creating TCP connections. If
57+ // NetDialContext is nil, net.DialContext is used.
58+ NetDialContext func (ctx context.Context , network , addr string ) (net.Conn , error )
59+
5460 // Proxy specifies a function to return a proxy for a given
5561 // Request. If the function returns a non-nil error, the
5662 // request is aborted with the provided error.
@@ -95,6 +101,11 @@ type Dialer struct {
95101 Jar http.CookieJar
96102}
97103
104+ // Dial creates a new client connection by calling DialContext with a background context.
105+ func (d * Dialer ) Dial (urlStr string , requestHeader http.Header ) (* Conn , * http.Response , error ) {
106+ return d .DialContext (urlStr , requestHeader , context .Background ())
107+ }
108+
98109var errMalformedURL = errors .New ("malformed ws or wss URL" )
99110
100111func hostPortNoPort (u * url.URL ) (hostPort , hostNoPort string ) {
@@ -124,17 +135,18 @@ var DefaultDialer = &Dialer{
124135// nilDialer is dialer to use when receiver is nil.
125136var nilDialer Dialer = * DefaultDialer
126137
127- // Dial creates a new client connection. Use requestHeader to specify the
138+ // DialContext creates a new client connection. Use requestHeader to specify the
128139// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
129140// Use the response.Header to get the selected subprotocol
130141// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
131142//
143+ // The context will be used in the request and in the Dialer
144+ //
132145// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
133146// non-nil *http.Response so that callers can handle redirects, authentication,
134147// etcetera. The response body may not contain the entire response and does not
135148// need to be closed by the application.
136- func (d * Dialer ) Dial (urlStr string , requestHeader http.Header ) (* Conn , * http.Response , error ) {
137-
149+ func (d * Dialer ) DialContext (urlStr string , requestHeader http.Header , ctx context.Context ) (* Conn , * http.Response , error ) {
138150 if d == nil {
139151 d = & nilDialer
140152 }
@@ -172,6 +184,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
172184 Header : make (http.Header ),
173185 Host : u .Host ,
174186 }
187+ req = req .WithContext (ctx )
175188
176189 // Set the cookies present in the cookie jar of the dialer
177190 if d .Jar != nil {
@@ -215,20 +228,30 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
215228 req .Header ["Sec-WebSocket-Extensions" ] = []string {"permessage-deflate; server_no_context_takeover; client_no_context_takeover" }
216229 }
217230
218- var deadline time.Time
219231 if d .HandshakeTimeout != 0 {
220- deadline = time .Now ().Add (d .HandshakeTimeout )
232+ var cancel func ()
233+ ctx , cancel = context .WithTimeout (ctx , d .HandshakeTimeout )
234+ defer cancel ()
221235 }
222236
223237 // Get network dial function.
224- netDial := d .NetDial
225- if netDial == nil {
226- netDialer := & net.Dialer {Deadline : deadline }
227- netDial = netDialer .Dial
238+ var netDial func (network , add string ) (net.Conn , error )
239+
240+ if d .NetDialContext != nil {
241+ netDial = func (network , addr string ) (net.Conn , error ) {
242+ return d .NetDialContext (ctx , network , addr )
243+ }
244+ } else if d .NetDial != nil {
245+ netDial = d .NetDial
246+ } else {
247+ netDialer := & net.Dialer {}
248+ netDial = func (network , addr string ) (net.Conn , error ) {
249+ return netDialer .DialContext (ctx , network , addr )
250+ }
228251 }
229252
230253 // If needed, wrap the dial function to set the connection deadline.
231- if ! deadline . Equal (time. Time {}) {
254+ if deadline , ok := ctx . Deadline (); ok {
232255 forwardDial := netDial
233256 netDial = func (network , addr string ) (net.Conn , error ) {
234257 c , err := forwardDial (network , addr )
@@ -260,7 +283,17 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
260283 }
261284
262285 hostPort , hostNoPort := hostPortNoPort (u )
286+ trace := httptrace .ContextClientTrace (ctx )
287+ if trace != nil && trace .GetConn != nil {
288+ trace .GetConn (hostPort )
289+ }
290+
263291 netConn , err := netDial ("tcp" , hostPort )
292+ if trace != nil && trace .GotConn != nil {
293+ trace .GotConn (httptrace.GotConnInfo {
294+ Conn : netConn ,
295+ })
296+ }
264297 if err != nil {
265298 return nil , nil , err
266299 }
@@ -278,13 +311,16 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
278311 }
279312 tlsConn := tls .Client (netConn , cfg )
280313 netConn = tlsConn
281- if err := tlsConn .Handshake (); err != nil {
282- return nil , nil , err
314+
315+ var err error
316+ if trace != nil {
317+ err = doHandshakeWithTrace (trace , tlsConn , cfg )
318+ } else {
319+ err = doHandshake (tlsConn , cfg )
283320 }
284- if ! cfg .InsecureSkipVerify {
285- if err := tlsConn .VerifyHostname (cfg .ServerName ); err != nil {
286- return nil , nil , err
287- }
321+
322+ if err != nil {
323+ return nil , nil , err
288324 }
289325 }
290326
@@ -294,6 +330,12 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
294330 return nil , nil , err
295331 }
296332
333+ if trace != nil && trace .GotFirstResponseByte != nil {
334+ if peek , err := conn .br .Peek (1 ); err == nil && len (peek ) == 1 {
335+ trace .GotFirstResponseByte ()
336+ }
337+ }
338+
297339 resp , err := http .ReadResponse (conn .br , req )
298340 if err != nil {
299341 return nil , nil , err
@@ -339,3 +381,15 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
339381 netConn = nil // to avoid close in defer.
340382 return conn , resp , nil
341383}
384+
385+ func doHandshake (tlsConn * tls.Conn , cfg * tls.Config ) error {
386+ if err := tlsConn .Handshake (); err != nil {
387+ return err
388+ }
389+ if ! cfg .InsecureSkipVerify {
390+ if err := tlsConn .VerifyHostname (cfg .ServerName ); err != nil {
391+ return err
392+ }
393+ }
394+ return nil
395+ }
0 commit comments