Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ type Conn struct {
readErr error
conn net.Conn
bufReader *bufio.Reader
reader io.Reader
header *Header
ProxyHeaderPolicy Policy
Validate Validator
Expand Down Expand Up @@ -171,7 +170,6 @@ func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {

pConn := &Conn{
bufReader: br,
reader: io.MultiReader(br, conn),
conn: conn,
}

Expand All @@ -191,7 +189,31 @@ func (p *Conn) Read(b []byte) (int, error) {
return 0, err
}

return p.reader.Read(b)
// Drain the buffer if it exists and has data.
if p.bufReader != nil {
if p.bufReader.Buffered() > 0 {
n, err := p.bufReader.Read(b)

// Did we empty the buffer?
// Buffering a net.Conn means the buffer doesn't return io.EOF until the connection returns io.EOF.
// Therefore, we use Buffered() == 0 to detect if we are done with the buffer.
if p.bufReader.Buffered() == 0 {
// Garbage collect the buffer.
p.bufReader = nil
}

// Return immediately. Do not touch p.conn.
// If err is EOF here, it means the connection is actually closed,
// so we should return that error to the user anyway.
return n, err
}
// If buffer was empty to begin with (shouldn't happen with the >0 check
// but good for safety), clear it.
p.bufReader = nil
}

// From now on, read directly from the underlying connection.
return p.conn.Read(b)
}

// Write wraps original conn.Write.
Expand Down Expand Up @@ -411,6 +433,11 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) {
return 0, err
}

// If the buffer has been drained (or cleared), copy directly from conn.
if p.bufReader == nil {
return io.Copy(w, p.conn)
}

b := make([]byte, p.bufReader.Buffered())
if _, err := p.bufReader.Read(b); err != nil {
return 0, err // this should never happen as we read buffered data.
Expand Down
172 changes: 172 additions & 0 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2371,6 +2371,178 @@ func TestWriteToDrainsBufferedData(t *testing.T) {
}
}

// chunkedConn wraps a net.Conn and limits reads to simulate TCP chunking.
type chunkedConn struct {
net.Conn
maxRead int
readCalls int
bytesRead int
}

func (c *chunkedConn) Read(b []byte) (int, error) {
if len(b) > c.maxRead {
b = b[:c.maxRead]
}
n, err := c.Conn.Read(b)
if n > 0 {
c.readCalls++
c.bytesRead += n
}
return n, err
}

// TestConnReadHandlesChunkedPayload verifies Conn.Read does not drop data
// when the initial TCP read is smaller than the payload.
func TestConnReadHandlesChunkedPayload(t *testing.T) {
const payloadSize = 400

proxyHeader := []byte("PROXY TCP4 192.168.1.1 192.168.1.2 12345 443\r\n")
payload := bytes.Repeat([]byte("X"), payloadSize)
fullData := append(proxyHeader, payload...)

serverConn, clientConn := net.Pipe()
defer func() {
serverCloseErr := serverConn.Close()
clientCloseErr := clientConn.Close()
if serverCloseErr != nil || clientCloseErr != nil {
t.Errorf("failed to close connection: %v, %v", serverCloseErr, clientCloseErr)
}
}()

go func() {
_, _ = clientConn.Write(fullData)
_ = clientConn.Close()
}()

// Simulate TCP delivering only 256 bytes in first read.
chunked := &chunkedConn{Conn: serverConn, maxRead: 256}

// Create a ProxyProto-wrapped connection.
conn := NewConn(chunked)
buf := make([]byte, 64)
readPayload := make([]byte, 0, payloadSize)
for len(readPayload) < payloadSize {
_ = conn.SetReadDeadline(time.Now().Add(time.Second))
n, err := conn.Read(buf)
if err != nil && err != io.EOF {
t.Fatalf("unexpected read error: %v", err)
}
if n > 0 {
readPayload = append(readPayload, buf[:n]...)
}
if err == io.EOF {
break
}
}
Comment on lines +2424 to +2436
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test is also passing with the previous MultiReader approach (although back then we always used the bufferedReader and never directly the conn, but that is also not tested).

If you want to stick with adhering strictly to the reader contract, I would modify this test so that we actually validate that we reach the underlying Conn (I can do that if you let me modify your branch). I.e. so that this test fails on the earlier version of the code. the modification amounts to doing a double-read with a 4KiB buffer and using a payload like 1KiB:

  1. we receive the buffered bytes
  2. we receive the remaining bytes with the following Read() call.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to open a PR against this PR. Since I'm working directly on this repo, I would have to give you write permissions to it so you could write to this branch, which is not ideal.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC the test fails with the earlier version of the code but passes now as documented in #148 (comment).


t.Logf("Sent: %d bytes payload (after %d byte PROXY header)", payloadSize, len(proxyHeader))
t.Logf("Read: %d bytes", len(readPayload))

if len(readPayload) != payloadSize {
t.Fatalf("read %d bytes, expected %d", len(readPayload), payloadSize)
}
if !bytes.Equal(readPayload, payload) {
t.Fatalf("payload mismatch")
}

// Ensure the proxy connection read from the underlying conn
// and drained all bytes, not just buffered reads.
if chunked.readCalls == 0 {
t.Fatalf("expected underlying reads to occur")
}
if chunked.bytesRead <= len(proxyHeader) {
t.Fatalf("expected reads beyond header, got %d bytes", chunked.bytesRead)
}
if chunked.bytesRead != len(fullData) {
t.Fatalf("underlying reads=%d bytes, expected %d", chunked.bytesRead, len(fullData))
}
}

func TestReadUsesConnWhenBufReaderNil(t *testing.T) {
serverConn, clientConn := net.Pipe()
t.Cleanup(func() {
if closeErr := serverConn.Close(); closeErr != nil {
t.Errorf("failed to close server connection: %v", closeErr)
}
})
t.Cleanup(func() {
if closeErr := clientConn.Close(); closeErr != nil {
t.Errorf("failed to close client connection: %v", closeErr)
}
})

proxyConn := NewConn(serverConn)
sendSecond := make(chan struct{})

go func() {
_, _ = clientConn.Write([]byte("a"))
<-sendSecond
_, _ = clientConn.Write([]byte("b"))
_ = clientConn.Close()
}()

buf := make([]byte, 1)
// First read processes header detection and drains the buffer.
if _, err := proxyConn.Read(buf); err != nil {
t.Fatalf("first read failed: %v", err)
}
if proxyConn.bufReader != nil {
t.Fatalf("expected bufReader to be nil after draining buffer")
}

// With bufReader cleared, Read should use the underlying conn.
close(sendSecond)
if _, err := proxyConn.Read(buf); err != nil {
t.Fatalf("second read failed: %v", err)
}
if string(buf) != "b" {
t.Fatalf("unexpected second read payload: %q", string(buf))
}
}

func TestWriteToUsesConnWhenBufReaderNil(t *testing.T) {
serverConn, clientConn := net.Pipe()
t.Cleanup(func() {
if closeErr := serverConn.Close(); closeErr != nil {
t.Errorf("failed to close server connection: %v", closeErr)
}
})
t.Cleanup(func() {
if closeErr := clientConn.Close(); closeErr != nil {
t.Errorf("failed to close client connection: %v", closeErr)
}
})

proxyConn := NewConn(serverConn)
sendPayload := make(chan struct{})

go func() {
_, _ = clientConn.Write([]byte("x"))
<-sendPayload
_, _ = clientConn.Write([]byte("payload"))
_ = clientConn.Close()
}()

// Process header detection and drain the buffer.
buf := make([]byte, 1)
if _, err := proxyConn.Read(buf); err != nil {
t.Fatalf("initial read failed: %v", err)
}
if proxyConn.bufReader != nil {
t.Fatalf("expected bufReader to be nil after draining buffer")
}

// With bufReader cleared, WriteTo should copy directly from conn.
close(sendPayload)
var out bytes.Buffer
if _, err := proxyConn.WriteTo(&out); err != nil {
t.Fatalf("WriteTo failed: %v", err)
}
if out.String() != "payload" {
t.Fatalf("unexpected WriteTo output: %q", out.String())
}
}

func benchmarkTCPProxy(size int, b *testing.B) {
// create and start the echo backend
backend, err := net.Listen("tcp", testLocalhostRandomPort)
Expand Down
Loading