diff --git a/header.go b/header.go index 48d3fa7..db877e8 100644 --- a/header.go +++ b/header.go @@ -54,7 +54,7 @@ func HeaderProxyFromAddrs(version byte, sourceAddr, destAddr net.Addr) *Header { } h := &Header{ Version: version, - Command: PROXY, + Command: LOCAL, TransportProtocol: UNSPEC, } switch sourceAddr := sourceAddr.(type) { @@ -88,6 +88,7 @@ func HeaderProxyFromAddrs(version byte, sourceAddr, destAddr net.Addr) *Header { } } if h.TransportProtocol != UNSPEC { + h.Command = PROXY h.SourceAddr = sourceAddr h.DestinationAddr = destAddr } @@ -152,17 +153,15 @@ func (header *Header) EqualsTo(otherHeader *Header) bool { if otherHeader == nil { return false } - if header.Command.IsLocal() { - return true - } // TLVs only exist for version 2 - if header.Version == 0x02 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) { + if header.Version == 2 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) { return false } - if header.Version != otherHeader.Version || header.TransportProtocol != otherHeader.TransportProtocol { + if header.Version != otherHeader.Version || header.Command != otherHeader.Command || header.TransportProtocol != otherHeader.TransportProtocol { return false } - if header.TransportProtocol == UNSPEC { + // Return early for header with LOCAL command, which contains no address information + if header.Command == LOCAL { return true } return header.SourceAddr.String() == otherHeader.SourceAddr.String() && diff --git a/header_test.go b/header_test.go index 8cb4977..5385369 100644 --- a/header_test.go +++ b/header_test.go @@ -537,7 +537,7 @@ func TestFormatInvalid(t *testing.T) { func TestHeaderProxyFromAddrs(t *testing.T) { unspec := &Header{ Version: 2, - Command: PROXY, + Command: LOCAL, TransportProtocol: UNSPEC, } diff --git a/v1.go b/v1.go index 0020d79..9ff686a 100644 --- a/v1.go +++ b/v1.go @@ -33,62 +33,68 @@ func parseVersion1(reader *bufio.Reader) (*Header, error) { } // Check full signature. tokens := strings.Split(line[:len(line)-2], separator) - transportProtocol := UNSPEC // doesn't exist in v1 but fits UNKNOWN. - if len(tokens) > 0 { - // Read address family and protocol - switch tokens[1] { - case "TCP4": - transportProtocol = TCPv4 - case "TCP6": - transportProtocol = TCPv6 - case "UNKNOWN": // no-op as UNSPEC is set already - default: - return nil, ErrCantReadAddressFamilyAndProtocol - } - - // Expect 6 tokens only when UNKNOWN is not present. - if !transportProtocol.IsUnspec() && len(tokens) < 6 { - return nil, ErrCantReadAddressFamilyAndProtocol - } - } - - // Allocation only happens when a signature is found. + + // Expect at least 2 tokens: "PROXY" and the transport protocol. + if len(tokens) < 2 { + return nil, ErrCantReadAddressFamilyAndProtocol + } + + // Read address family and protocol + var transportProtocol AddressFamilyAndProtocol + switch tokens[1] { + case "TCP4": + transportProtocol = TCPv4 + case "TCP6": + transportProtocol = TCPv6 + case "UNKNOWN": + transportProtocol = UNSPEC // doesn't exist in v1 but fits UNKNOWN + default: + return nil, ErrCantReadAddressFamilyAndProtocol + } + + // Expect 6 tokens only when UNKNOWN is not present. + if transportProtocol != UNSPEC && len(tokens) < 6 { + return nil, ErrCantReadAddressFamilyAndProtocol + } + + // When a signature is found, allocate a v1 header with Command set to PROXY. + // Command doesn't exist in v1 but set it for other parts of this library + // to rely on it for determining connection details. header := initVersion1() - // If UNKNOWN is present, set Command to LOCAL. - // Command is not present in v1 but set it for other parts of - // this library to rely on it for determining connection details. - header.Command = LOCAL // Transport protocol has been processed already. header.TransportProtocol = transportProtocol - // Only process further if UNKNOWN is not present. - if header.TransportProtocol != UNSPEC { - // Read addresses and ports - sourceIP, err := parseV1IPAddress(header.TransportProtocol, tokens[2]) - if err != nil { - return nil, err - } - destIP, err := parseV1IPAddress(header.TransportProtocol, tokens[3]) - if err != nil { - return nil, err - } - sourcePort, err := parseV1PortNumber(tokens[4]) - if err != nil { - return nil, err - } - destPort, err := parseV1PortNumber(tokens[5]) - if err != nil { - return nil, err - } - header.SourceAddr = &net.TCPAddr{ - IP: sourceIP, - Port: sourcePort, - } - header.DestinationAddr = &net.TCPAddr{ - IP: destIP, - Port: destPort, - } + // When UNKNOWN, set the command to LOCAL and return early + if header.TransportProtocol == UNSPEC { + header.Command = LOCAL + return header, nil + } + + // Otherwise, continue to read addresses and ports + sourceIP, err := parseV1IPAddress(header.TransportProtocol, tokens[2]) + if err != nil { + return nil, err + } + destIP, err := parseV1IPAddress(header.TransportProtocol, tokens[3]) + if err != nil { + return nil, err + } + sourcePort, err := parseV1PortNumber(tokens[4]) + if err != nil { + return nil, err + } + destPort, err := parseV1PortNumber(tokens[5]) + if err != nil { + return nil, err + } + header.SourceAddr = &net.TCPAddr{ + IP: sourceIP, + Port: sourcePort, + } + header.DestinationAddr = &net.TCPAddr{ + IP: destIP, + Port: destPort, } return header, nil diff --git a/v1_test.go b/v1_test.go index 76628b7..8d6dab1 100644 --- a/v1_test.go +++ b/v1_test.go @@ -45,6 +45,11 @@ var invalidParseV1Tests = []struct { reader: newBufioReader([]byte("PROXY " + crlf)), expectedError: ErrCantReadAddressFamilyAndProtocol, }, + { + desc: "proxy no space crlf", + reader: newBufioReader([]byte("PROXY" + crlf)), + expectedError: ErrCantReadAddressFamilyAndProtocol, + }, { desc: "proxy something crlf", reader: newBufioReader([]byte("PROXY SOMETHING" + crlf)), @@ -114,7 +119,7 @@ var validParseAndWriteV1Tests = []struct { reader: bufio.NewReader(strings.NewReader(fixtureUnknown)), expectedHeader: &Header{ Version: 1, - Command: PROXY, + Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, @@ -125,7 +130,7 @@ var validParseAndWriteV1Tests = []struct { reader: bufio.NewReader(strings.NewReader(fixtureUnknownWithAddresses)), expectedHeader: &Header{ Version: 1, - Command: PROXY, + Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, diff --git a/v2.go b/v2.go index 1dd8db6..74bf3f0 100644 --- a/v2.go +++ b/v2.go @@ -205,7 +205,6 @@ func (header *Header) formatVersion2() ([]byte, error) { addrDst = formatUnixName(destAddr.Name) } - // if addrSrc == nil || addrDst == nil { return nil, ErrInvalidAddress } diff --git a/version_cmd.go b/version_cmd.go index 67ce303..59f2042 100644 --- a/version_cmd.go +++ b/version_cmd.go @@ -1,10 +1,16 @@ package proxyproto -// ProtocolVersionAndCommand represents proxy protocol version and command. +// ProtocolVersionAndCommand represents the command in proxy protocol v2. +// Command doesn't exist in v1 but it should be set since other parts of +// this library may rely on it for determining connection details. type ProtocolVersionAndCommand byte const ( + // LOCAL represents the LOCAL command in v2 or UNKNOWN transport in v1, + // in which case no address information is expected. LOCAL ProtocolVersionAndCommand = '\x20' + // PROXY represents the PROXY command in v2 or transport is not UNKNOWN in v1, + // in which case valid local/remote address and port information is expected. PROXY ProtocolVersionAndCommand = '\x21' ) @@ -13,17 +19,19 @@ var supportedCommand = map[ProtocolVersionAndCommand]bool{ PROXY: true, } -// IsLocal returns true if the protocol version is \x2 and command is LOCAL, false otherwise. +// IsLocal returns true if the command in v2 is LOCAL or the transport in v1 is UNKNOWN, +// i.e. when no address information is expected, false otherwise. func (pvc ProtocolVersionAndCommand) IsLocal() bool { - return 0x20 == pvc&0xF0 && 0x00 == pvc&0x0F + return LOCAL == pvc } -// IsProxy returns true if the protocol version is \x2 and command is PROXY, false otherwise. +// IsProxy returns true if the command in v2 is PROXY or the transport in v1 is not UNKNOWN, +// i.e. when valid local/remote address and port information is expected, false otherwise. func (pvc ProtocolVersionAndCommand) IsProxy() bool { - return 0x20 == pvc&0xF0 && 0x01 == pvc&0x0F + return PROXY == pvc } -// IsUnspec returns true if the protocol version or command is unspecified, false otherwise. +// IsUnspec returns true if the command is unspecified, false otherwise. func (pvc ProtocolVersionAndCommand) IsUnspec() bool { return !(pvc.IsLocal() || pvc.IsProxy()) }