Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
go: [ '1.15', '1.14' ]
go: [ '1.15' ]
steps:
- uses: actions/checkout@v2
- name: Set up Go
Expand Down
56 changes: 50 additions & 6 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,19 @@ import (
"time"
)

// DefaultReadHeaderTimeout is how long header processing waits for header to
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
// It's kept as a global variable so to make it easier to find and override,
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
var DefaultReadHeaderTimeout = 200 * time.Millisecond

// Listener is used to wrap an underlying listener,
// whose connections may be using the HAProxy Proxy Protocol.
// If the connection is using the protocol, the RemoteAddr() will return
// the correct client address.
// the correct client address. ReadHeaderTimeout will be applied to all
// connections in order to prevent blocking operations. If no ReadHeaderTimeout
// is set, a default of 200ms will be used. This can be disabled by setting the
// timeout to < 0.
type Listener struct {
Listener net.Listener
Policy PolicyFunc
Expand All @@ -21,7 +30,8 @@ type Listener struct {

// Conn is used to wrap and underlying connection which
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
// return the address of the client instead of the proxy address.
// return the address of the client instead of the proxy address. Each connection
// will have its own readHeaderTimeout and readDeadline set by the Accept() call.
type Conn struct {
bufReader *bufio.Reader
conn net.Conn
Expand All @@ -30,6 +40,8 @@ type Conn struct {
ProxyHeaderPolicy Policy
Validate Validator
readErr error
readHeaderTimeout time.Duration
readDeadline time.Time
}

// Validator receives a header and decides whether it is a valid one
Expand All @@ -53,10 +65,6 @@ func (p *Listener) Accept() (net.Conn, error) {
return nil, err
}

if d := p.ReadHeaderTimeout; d != 0 {
conn.SetReadDeadline(time.Now().Add(d))
}

proxyHeaderPolicy := USE
if p.Policy != nil {
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
Expand All @@ -72,6 +80,15 @@ func (p *Listener) Accept() (net.Conn, error) {
WithPolicy(proxyHeaderPolicy),
ValidateHeader(p.ValidateHeader),
)

// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
if p.ReadHeaderTimeout == 0 {
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
}

// Set the readHeaderTimeout of the new conn to the value of the listener
newConn.readHeaderTimeout = p.ReadHeaderTimeout

return newConn, nil
}

Expand Down Expand Up @@ -110,6 +127,7 @@ func (p *Conn) Read(b []byte) (int, error) {
if p.readErr != nil {
return 0, p.readErr
}

return p.bufReader.Read(b)
}

Expand Down Expand Up @@ -197,11 +215,16 @@ func (p *Conn) UDPConn() (conn *net.UDPConn, ok bool) {

// SetDeadline wraps original conn.SetDeadline
func (p *Conn) SetDeadline(t time.Time) error {
p.readDeadline = t
return p.conn.SetDeadline(t)
}

// SetReadDeadline wraps original conn.SetReadDeadline
func (p *Conn) SetReadDeadline(t time.Time) error {
// Set a local var that tells us the desired deadline. This is
// needed in order to reset the read deadline to the one that is
// desired by the user, rather than an empty deadline.
p.readDeadline = t
return p.conn.SetReadDeadline(t)
}

Expand All @@ -211,7 +234,28 @@ func (p *Conn) SetWriteDeadline(t time.Time) error {
}

func (p *Conn) readHeader() error {
// If the connection's readHeaderTimeout is more than 0,
// push our deadline back to now plus the timeout. This should only
// run on the connection, as we don't want to override the previous
// read deadline the user may have used.
if p.readHeaderTimeout > 0 {
p.conn.SetReadDeadline(time.Now().Add(p.readHeaderTimeout))
}

header, err := Read(p.bufReader)

// If the connection's readHeaderTimeout is more than 0, undo the change to the
// deadline that we made above. Because we retain the readDeadline as part of our
// SetReadDeadline override, we know the user's desired deadline so we use that.
// Therefore, we check whether the error is a net.Timeout and if it is, we decide
// the proxy proto does not exist and set the error accordingly.
if p.readHeaderTimeout > 0 {
p.conn.SetReadDeadline(p.readDeadline)
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
err = ErrNoProxyProtocol
}
}

// For the purpose of this wrapper shamefully stolen from armon/go-proxyproto
// let's act as if there was no error when PROXY protocol is not present.
if err == ErrNoProxyProtocol {
Expand Down
Loading