Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,23 @@ import (
// Stuff to be used in both versions tests.

const (
NO_PROTOCOL = "There is no spoon"
IP4_ADDR = "127.0.0.1"
IP6_ADDR = "::1"
IP6_LONG_ADDR = "1234:5678:9abc:def0:cafe:babe:dead:2bad"
PORT = 65533
INVALID_PORT = 99999
NO_PROTOCOL = "There is no spoon"
IP4_ADDR = "127.0.0.1"
IP6_ADDR = "::1"
IP6_LONG_ADDR = "1234:5678:9abc:def0:cafe:babe:dead:2bad"
IP6_COMPAT_ADDR = "0:0:0:0:0:ffff:7f00:1"
PORT = 65533
INVALID_PORT = 99999
)

var (
v4ip = net.ParseIP(IP4_ADDR).To4()
v6ip = net.ParseIP(IP6_ADDR).To16()
v4ip = net.ParseIP(IP4_ADDR).To4()
v6ip = net.ParseIP(IP6_ADDR).To16()
v6CompatIP = net.ParseIP(IP4_ADDR).To16()

v4addr net.Addr = &net.TCPAddr{IP: v4ip, Port: PORT}
v6addr net.Addr = &net.TCPAddr{IP: v6ip, Port: PORT}
v4addr net.Addr = &net.TCPAddr{IP: v4ip, Port: PORT}
v6addr net.Addr = &net.TCPAddr{IP: v6ip, Port: PORT}
v6CompatAddr net.Addr = &net.TCPAddr{IP: v6CompatIP, Port: PORT}

v4UDPAddr net.Addr = &net.UDPAddr{IP: v4ip, Port: PORT}
v6UDPAddr net.Addr = &net.UDPAddr{IP: v6ip, Port: PORT}
Expand Down
61 changes: 53 additions & 8 deletions v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,27 @@ func (header *Header) formatVersion1() ([]byte, error) {
sourceIP = sourceIP.To16()
destIP = destIP.To16()
}

if sourceIP == nil || destIP == nil {
return nil, ErrInvalidAddress
}

ipToString := func(ip net.IP) string {
if header.TransportProtocol == TCPv6 && ip.To4() != nil {
return fmt.Sprintf("::FFFF:%s", ip.String())
}

return ip.String()
}

buf := bytes.NewBuffer(make([]byte, 0, 108))
buf.Write(SIGV1)
buf.WriteString(separator)
buf.WriteString(proto)
buf.WriteString(separator)
buf.WriteString(sourceIP.String())
buf.WriteString(ipToString(sourceIP))
buf.WriteString(separator)
buf.WriteString(destIP.String())
buf.WriteString(ipToString(destIP))
buf.WriteString(separator)
buf.WriteString(strconv.Itoa(sourceAddr.Port))
buf.WriteString(separator)
Expand All @@ -221,11 +230,47 @@ func parseV1PortNumber(portStr string) (int, error) {
return port, nil
}

func parseV1IPAddress(protocol AddressFamilyAndProtocol, addrStr string) (addr net.IP, err error) {
addr = net.ParseIP(addrStr)
tryV4 := addr.To4()
if (protocol == TCPv4 && tryV4 == nil) || (protocol == TCPv6 && tryV4 != nil) {
err = ErrInvalidAddress
func parseV1IPAddress(protocol AddressFamilyAndProtocol, addrStr string) (net.IP, error) {

addr := net.ParseIP(addrStr)
if addr == nil {
return nil, ErrInvalidAddress
}

if protocol == TCPv4 && addr.To4() == nil {
return nil, ErrInvalidAddress
} else if protocol == TCPv4 {
return addr, nil
}

if protocol == TCPv6 && !strings.Contains(addr.String(), ".") {
return addr, nil
}
return

// This check is not foolproof, but it's all we can using the net/IP library
if protocol == TCPv6 && !strings.Contains(strings.ToLower(addrStr), ":ffff:") {
return nil, ErrInvalidAddress
}

if protocol == TCPv6 && isIpv4InIpv6(addr) {
return addr, nil
}

return nil, ErrInvalidAddress
}

func isIpv4InIpv6(ip net.IP) bool {
isZeros := func(p net.IP) bool {
for i := 0; i < len(p); i++ {
if p[i] != 0 {
return false
}
}
return true
}

return len(ip) == net.IPv6len &&
isZeros(ip[0:10]) &&
ip[10] == 0xff &&
ip[11] == 0xff
}
16 changes: 15 additions & 1 deletion v1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ var (
IPv4AddressesAndInvalidPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(INVALID_PORT), strconv.Itoa(INVALID_PORT)}, separator)
IPv6AddressesAndPorts = strings.Join([]string{IP6_ADDR, IP6_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator)
IPv6LongAddressesAndPorts = strings.Join([]string{IP6_LONG_ADDR, IP6_LONG_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator)
TCP6CompatAddressesAndPorts = strings.Join([]string{IP6_COMPAT_ADDR, IP6_COMPAT_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator)

fixtureTCP4V1 = "PROXY TCP4 " + IPv4AddressesAndPorts + crlf + "GET /"
fixtureTCP6V1 = "PROXY TCP6 " + IPv6AddressesAndPorts + crlf + "GET /"
Expand All @@ -24,6 +25,8 @@ var (

fixtureUnknown = "PROXY UNKNOWN" + crlf
fixtureUnknownWithAddresses = "PROXY UNKNOWN " + IPv4AddressesAndInvalidPorts + crlf

fixtureTCP6CompatV1 = "PROXY TCP6 " + TCP6CompatAddressesAndPorts + crlf + "GET /"
)

var invalidParseV1Tests = []struct {
Expand Down Expand Up @@ -92,7 +95,7 @@ func TestReadV1Invalid(t *testing.T) {
for _, tt := range invalidParseV1Tests {
t.Run(tt.desc, func(t *testing.T) {
if _, err := Read(tt.reader); err != tt.expectedError {
t.Fatalf("expected %s, actual %v", tt.expectedError, err)
t.Fatalf("TestReadV1Invalid: expected %s, actual %v", tt.expectedError, err)
}
})
}
Expand Down Expand Up @@ -147,6 +150,17 @@ var validParseAndWriteV1Tests = []struct {
DestinationAddr: nil,
},
},
{
desc: "tcp6 compat v1",
reader: bufio.NewReader(strings.NewReader(fixtureTCP6CompatV1)),
expectedHeader: &Header{
Version: 1,
Command: PROXY,
TransportProtocol: TCPv6,
SourceAddr: v6CompatAddr,
DestinationAddr: v6CompatAddr,
},
},
}

func TestParseV1Valid(t *testing.T) {
Expand Down