Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions addr_proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,6 @@ const (
UnixDatagram AddressFamilyAndProtocol = '\x32'
)

var supportedTransportProtocol = map[AddressFamilyAndProtocol]bool{
TCPv4: true,
UDPv4: true,
TCPv6: true,
UDPv6: true,
UnixStream: true,
UnixDatagram: true,
}

// IsIPv4 returns true if the address family is IPv4 (AF_INET4), false otherwise.
func (ap AddressFamilyAndProtocol) IsIPv4() bool {
return 0x10 == ap&0xF0
Expand Down
1 change: 1 addition & 0 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ var (
SIGV1 = []byte{'\x50', '\x52', '\x4F', '\x58', '\x59'}
SIGV2 = []byte{'\x0D', '\x0A', '\x0D', '\x0A', '\x00', '\x0D', '\x0A', '\x51', '\x55', '\x49', '\x54', '\x0A'}

ErrLineMustEndWithCrlf = errors.New("proxyproto: header is invalid, must end with \\r\\n")
ErrCantReadProtocolVersionAndCommand = errors.New("proxyproto: can't read proxy protocol version and command")
ErrCantReadAddressFamilyAndProtocol = errors.New("proxyproto: can't read address family or protocol")
ErrCantReadLength = errors.New("proxyproto: can't read length")
Expand Down
8 changes: 4 additions & 4 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,15 @@ func (p *Conn) RemoteAddr() net.Addr {
// Raw returns the underlying connection which can be casted to
// a concrete type, allowing access to specialized functions.
//
// Use this ONLY if you know exactly what you are doing.
// Use this ONLY if you know exactly what you are doing.
func (p *Conn) Raw() net.Conn {
return p.conn
}

// TCPConn returns the underlying TCP connection,
// allowing access to specialized functions.
//
// Use this ONLY if you know exactly what you are doing.
// Use this ONLY if you know exactly what you are doing.
func (p *Conn) TCPConn() (conn *net.TCPConn, ok bool) {
conn, ok = p.conn.(*net.TCPConn)
return
Expand All @@ -174,7 +174,7 @@ func (p *Conn) TCPConn() (conn *net.TCPConn, ok bool) {
// UnixConn returns the underlying Unix socket connection,
// allowing access to specialized functions.
//
// Use this ONLY if you know exactly what you are doing.
// Use this ONLY if you know exactly what you are doing.
func (p *Conn) UnixConn() (conn *net.UnixConn, ok bool) {
conn, ok = p.conn.(*net.UnixConn)
return
Expand All @@ -183,7 +183,7 @@ func (p *Conn) UnixConn() (conn *net.UnixConn, ok bool) {
// UDPConn returns the underlying UDP connection,
// allowing access to specialized functions.
//
// Use this ONLY if you know exactly what you are doing.
// Use this ONLY if you know exactly what you are doing.
func (p *Conn) UDPConn() (conn *net.UDPConn, ok bool) {
conn, ok = p.conn.(*net.UDPConn)
return
Expand Down
97 changes: 59 additions & 38 deletions v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,52 +22,73 @@ func initVersion1() *Header {
}

func parseVersion1(reader *bufio.Reader) (*Header, error) {
// Make sure we have a v1 header
// Read until LF shows up, otherwise fail.
// At this point, can't be sure CR precedes LF which will be validated next.
line, err := reader.ReadString('\n')
if err != nil {
return nil, ErrLineMustEndWithCrlf
}
if !strings.HasSuffix(line, crlf) {
return nil, ErrCantReadProtocolVersionAndCommand
return nil, ErrLineMustEndWithCrlf
}
// Check full signature.
tokens := strings.Split(line[:len(line)-2], separator)
if len(tokens) < 6 {
return nil, ErrCantReadProtocolVersionAndCommand
transportProtocol := UNSPEC // doesn't exist in v1 but fits UNKNOWN.
if len(tokens) > 0 {
// Read address family and protocol
switch tokens[1] {
case "TCP4":
transportProtocol = TCPv4
case "TCP6":
transportProtocol = TCPv6
case "UNKNOWN": // no-op as UNSPEC is set already
default:
return nil, ErrCantReadAddressFamilyAndProtocol
}

// Expect 6 tokens only when UNKNOWN is not present.
if !transportProtocol.IsUnspec() && len(tokens) < 6 {
return nil, ErrCantReadAddressFamilyAndProtocol
}
}

// Allocation only happens when a signature is found.
header := initVersion1()
// If UNKNOWN is present, set Command to LOCAL.
// Command is not present in v1 but set it for other parts of
// this library to rely on it for determining connection details.
header.Command = LOCAL

// Read address family and protocol
switch tokens[1] {
case "TCP4":
header.TransportProtocol = TCPv4
case "TCP6":
header.TransportProtocol = TCPv6
default:
header.TransportProtocol = UNSPEC
}
// Transport protocol has been processed already.
header.TransportProtocol = transportProtocol

// Read addresses and ports
sourceIP, err := parseV1IPAddress(header.TransportProtocol, tokens[2])
if err != nil {
return nil, err
}
destIP, err := parseV1IPAddress(header.TransportProtocol, tokens[3])
if err != nil {
return nil, err
}
sourcePort, err := parseV1PortNumber(tokens[4])
if err != nil {
return nil, err
}
destPort, err := parseV1PortNumber(tokens[5])
if err != nil {
return nil, err
}
header.SourceAddr = &net.TCPAddr{
IP: sourceIP,
Port: sourcePort,
}
header.DestinationAddr = &net.TCPAddr{
IP: destIP,
Port: destPort,
// Only process further if UNKNOWN is not present.
if header.TransportProtocol != UNSPEC {
// Read addresses and ports
sourceIP, err := parseV1IPAddress(header.TransportProtocol, tokens[2])
if err != nil {
return nil, err
}
destIP, err := parseV1IPAddress(header.TransportProtocol, tokens[3])
if err != nil {
return nil, err
}
sourcePort, err := parseV1PortNumber(tokens[4])
if err != nil {
return nil, err
}
destPort, err := parseV1PortNumber(tokens[5])
if err != nil {
return nil, err
}
header.SourceAddr = &net.TCPAddr{
IP: sourceIP,
Port: sourcePort,
}
header.DestinationAddr = &net.TCPAddr{
IP: destIP,
Port: destPort,
}
}

return header, nil
Expand All @@ -84,7 +105,7 @@ func (header *Header) formatVersion1() ([]byte, error) {
proto = "TCP6"
default:
// Unknown connection (short form)
return []byte("PROXY UNKNOWN\r\n"), nil
return []byte("PROXY UNKNOWN" + crlf), nil
}

sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr)
Expand Down
150 changes: 101 additions & 49 deletions v1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,63 +9,88 @@ import (
)

var (
TCP4AddressesAndPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator)
TCP4AddressesAndInvalidPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(INVALID_PORT), strconv.Itoa(INVALID_PORT)}, separator)
TCP6AddressesAndPorts = strings.Join([]string{IP6_ADDR, IP6_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator)
IPv4AddressesAndPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator)
IPv4AddressesAndInvalidPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(INVALID_PORT), strconv.Itoa(INVALID_PORT)}, separator)
IPv6AddressesAndPorts = strings.Join([]string{IP6_ADDR, IP6_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator)

fixtureTCP4V1 = "PROXY TCP4 " + TCP4AddressesAndPorts + crlf + "GET /"
fixtureTCP6V1 = "PROXY TCP6 " + TCP6AddressesAndPorts + crlf + "GET /"
fixtureTCP4V1 = "PROXY TCP4 " + IPv4AddressesAndPorts + crlf + "GET /"
fixtureTCP6V1 = "PROXY TCP6 " + IPv6AddressesAndPorts + crlf + "GET /"

fixtureUnknown = "PROXY UNKNOWN" + crlf
fixtureUnknownWithAddresses = "PROXY UNKNOWN " + IPv4AddressesAndInvalidPorts + crlf
)

var invalidParseV1Tests = []struct {
desc string
reader *bufio.Reader
expectedError error
}{
{
newBufioReader([]byte("PROX")),
ErrNoProxyProtocol,
desc: "no signature",
reader: newBufioReader([]byte(NO_PROTOCOL)),
expectedError: ErrNoProxyProtocol,
},
{
desc: "prox",
reader: newBufioReader([]byte("PROX")),
expectedError: ErrNoProxyProtocol,
},
{
desc: "proxy lf",
reader: newBufioReader([]byte("PROXY \n")),
expectedError: ErrLineMustEndWithCrlf,
},
{
newBufioReader([]byte(NO_PROTOCOL)),
ErrNoProxyProtocol,
desc: "proxy crlf",
reader: newBufioReader([]byte("PROXY " + crlf)),
expectedError: ErrCantReadAddressFamilyAndProtocol,
},
{
newBufioReader([]byte("PROXY \r\n")),
ErrCantReadProtocolVersionAndCommand,
desc: "proxy something crlf",
reader: newBufioReader([]byte("PROXY SOMETHING" + crlf)),
expectedError: ErrCantReadAddressFamilyAndProtocol,
},
{
newBufioReader([]byte("PROXY TCP4 " + TCP4AddressesAndPorts)),
ErrCantReadProtocolVersionAndCommand,
desc: "incomplete signature TCP4",
reader: newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndPorts)),
expectedError: ErrLineMustEndWithCrlf,
},
{
newBufioReader([]byte("PROXY TCP6 " + TCP4AddressesAndPorts + crlf)),
ErrInvalidAddress,
desc: "TCP6 with IPv4 addresses",
reader: newBufioReader([]byte("PROXY TCP6 " + IPv4AddressesAndPorts + crlf)),
expectedError: ErrInvalidAddress,
},
{
newBufioReader([]byte("PROXY TCP4 " + TCP6AddressesAndPorts + crlf)),
ErrInvalidAddress,
desc: "TCP4 with IPv6 addresses",
reader: newBufioReader([]byte("PROXY TCP4 " + IPv6AddressesAndPorts + crlf)),
expectedError: ErrInvalidAddress,
},
// PROXY TCP IPv4
{newBufioReader([]byte("PROXY TCP4 " + TCP4AddressesAndInvalidPorts + crlf)),
ErrInvalidPortNumber,
{
desc: "TCP4 with invalid port",
reader: newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndInvalidPorts + crlf)),
expectedError: ErrInvalidPortNumber,
},
}

func TestReadV1Invalid(t *testing.T) {
for _, tt := range invalidParseV1Tests {
if _, err := Read(tt.reader); err != tt.expectedError {
t.Fatalf("TestReadV1Invalid: expected %s, actual %s", tt.expectedError, err.Error())
}
t.Run(tt.desc, func(t *testing.T) {
if _, err := Read(tt.reader); err != tt.expectedError {
t.Fatalf("expected %s, actual %s", tt.expectedError, err.Error())
}
})
}
}

var validParseAndWriteV1Tests = []struct {
desc string
reader *bufio.Reader
expectedHeader *Header
}{
{
bufio.NewReader(strings.NewReader(fixtureTCP4V1)),
&Header{
desc: "TCP4",
reader: bufio.NewReader(strings.NewReader(fixtureTCP4V1)),
expectedHeader: &Header{
Version: 1,
Command: PROXY,
TransportProtocol: TCPv4,
Expand All @@ -74,47 +99,74 @@ var validParseAndWriteV1Tests = []struct {
},
},
{
bufio.NewReader(strings.NewReader(fixtureTCP6V1)),
&Header{
desc: "TCP6",
reader: bufio.NewReader(strings.NewReader(fixtureTCP6V1)),
expectedHeader: &Header{
Version: 1,
Command: PROXY,
TransportProtocol: TCPv6,
SourceAddr: v6addr,
DestinationAddr: v6addr,
},
},
{
desc: "unknown",
reader: bufio.NewReader(strings.NewReader(fixtureUnknown)),
expectedHeader: &Header{
Version: 1,
Command: PROXY,
TransportProtocol: UNSPEC,
SourceAddr: nil,
DestinationAddr: nil,
},
},
{
desc: "unknown with addresses and ports",
reader: bufio.NewReader(strings.NewReader(fixtureUnknownWithAddresses)),
expectedHeader: &Header{
Version: 1,
Command: PROXY,
TransportProtocol: UNSPEC,
SourceAddr: nil,
DestinationAddr: nil,
},
},
}

func TestParseV1Valid(t *testing.T) {
for _, tt := range validParseAndWriteV1Tests {
header, err := Read(tt.reader)
if err != nil {
t.Fatal("TestParseV1Valid: unexpected error", err.Error())
}
if !header.EqualsTo(tt.expectedHeader) {
t.Fatalf("TestParseV1Valid: expected %#v, actual %#v", tt.expectedHeader, header)
}
t.Run(tt.desc, func(t *testing.T) {
header, err := Read(tt.reader)
if err != nil {
t.Fatal("unexpected error", err.Error())
}
if !header.EqualsTo(tt.expectedHeader) {
t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, header)
}
})
}
}

func TestWriteV1Valid(t *testing.T) {
for _, tt := range validParseAndWriteV1Tests {
var b bytes.Buffer
w := bufio.NewWriter(&b)
if _, err := tt.expectedHeader.WriteTo(w); err != nil {
t.Fatal("TestWriteV1Valid: Unexpected error ", err)
}
w.Flush()
t.Run(tt.desc, func(t *testing.T) {
var b bytes.Buffer
w := bufio.NewWriter(&b)
if _, err := tt.expectedHeader.WriteTo(w); err != nil {
t.Fatal("unexpected error ", err)
}
w.Flush()

// Read written bytes to validate written header
r := bufio.NewReader(&b)
newHeader, err := Read(r)
if err != nil {
t.Fatal("TestWriteV1Valid: Unexpected error ", err)
}
// Read written bytes to validate written header
r := bufio.NewReader(&b)
newHeader, err := Read(r)
if err != nil {
t.Fatal("unexpected error ", err)
}

if !newHeader.EqualsTo(tt.expectedHeader) {
t.Fatalf("TestWriteV1Valid: expected %#v, actual %#v", tt.expectedHeader, newHeader)
}
if !newHeader.EqualsTo(tt.expectedHeader) {
t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader)
}
})
}
}
Loading