Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ var nilDialer = *DefaultDialer
// non-nil *http.Response so that callers can handle redirects, authentication,
// etcetera. The response body may not contain the entire response and does not
// need to be closed by the application.
var maxErrorResponseSize = 4096

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The location of this variable declaration breaks the documentation (the DialContext documentation is no longer associated with the method). Move the var declaration.

While you are at it, change it to a const.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup indeed you're right there !


func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
if d == nil {
d = &nilDialer
Expand Down Expand Up @@ -364,9 +366,13 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// Before closing the network connection on return from this
// function, slurp up some of the response to aid application
// debugging.
buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf)
resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))

limReader := io.LimitReader(resp.Body, int64(maxErrorResponseSize))
Copy link

@BeatieWolfe BeatieWolfe Oct 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delete the conversion to int64. The conversion is not required if maxErrorResponseSize is changed from a variable to a constant.

buf, err := io.ReadAll(limReader)
if err != nil && err != io.EOF {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

io.ReadAll never returns io.EOF. Use if err != nil {.

buf = []byte{}
}
resp.Body = io.NopCloser(bytes.NewReader(buf))
return nil, resp, ErrBadHandshake
}

Expand Down
72 changes: 72 additions & 0 deletions client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,10 @@ func TestHandshake(t *testing.T) {
}

func TestRespOnBadHandshake(t *testing.T) {
// Test Body smaller than maxErrorResponseSize.
const expectedStatus = http.StatusGone
const expectedBody = "This is the response body."
const maxErrorResponseSize = 4096

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(expectedStatus)
Expand Down Expand Up @@ -604,6 +606,76 @@ func TestRespOnBadHandshake(t *testing.T) {
if string(p) != expectedBody {
t.Errorf("resp.Body=%s, want %s", p, expectedBody)
}

// Test Body larger than maxErrorResponseSize.
t.Run("ErrorResponseSizeLimited", func(t *testing.T) {
largeBody := make([]byte, maxErrorResponseSize+100)
for i := range largeBody {
largeBody[i] = 'a'
}

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write(largeBody)
}))
defer s.Close()

ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
if err == nil {
ws.Close()
t.Fatalf("Dial: expected error, got nil")
}

if resp == nil {
t.Fatalf("resp=nil, err=%v", err)
}

p, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ReadAll(resp.Body) returned error %v", err)
}

resp.Body.Close()

if len(p) > maxErrorResponseSize {
t.Fatalf("body size=%d, want <= %d", len(p), maxErrorResponseSize)
}
})

// Test Body exactly maxErrorResponseSize.
t.Run("ErrorResponseSizeExactLimit", func(t *testing.T) {
limitedBody := make([]byte, maxErrorResponseSize)
for i := range limitedBody {
limitedBody[i] = 'a'
}

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadGateway)
w.Write(limitedBody)
}))
defer s.Close()

ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
if err == nil {
ws.Close()
t.Fatalf("Dial: expected error, got nil")
}

if resp == nil {
t.Fatalf("resp=nil, err=%v", err)
}

p, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ReadAll(resp.Body) returned error %v", err)
}

resp.Body.Close()

if len(p) != maxErrorResponseSize {
t.Fatalf("body size=%d, want %d", len(p), maxErrorResponseSize)
}
})
Copy link

@BeatieWolfe BeatieWolfe Oct 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eliminate the duplicated test by using table driven test.

for _, t := range []struct{ size, expect int }{{maxErrorResponseSize + 100, maxErrorResponseSize}, {maxErrorResponseSize, maxErrorResponseSize}} {
        body := bytes.Repeat([]byte{'a'}, t.size)
        ...
		if len(p) != t.expected {
			t.Fatalf("body size=%d, want %d for original body size %d", len(p), t.expected, t.size)
		}
 }

}

type testLogWriter struct {
Expand Down