Skip to content

Commit f446ee1

Browse files
committed
Add support for validating the downstream ip of the connection
1 parent e5b291b commit f446ee1

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

protocol.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ var DefaultReadHeaderTimeout = 10 * time.Second
2525
type Listener struct {
2626
Listener net.Listener
2727
Policy PolicyFunc
28+
DownstreamPolicy PolicyFunc
2829
ValidateHeader Validator
2930
ReadHeaderTimeout time.Duration
3031
}
@@ -79,6 +80,18 @@ func (p *Listener) Accept() (net.Conn, error) {
7980
return conn, nil
8081
}
8182
}
83+
if p.DownstreamPolicy != nil {
84+
proxyHeaderPolicy, err = p.DownstreamPolicy(conn.LocalAddr())
85+
if err != nil {
86+
// can't decide the policy, we can't accept the connection
87+
conn.Close()
88+
return nil, err
89+
}
90+
// Handle a connection as a regular one
91+
if proxyHeaderPolicy == SKIP {
92+
return conn, nil
93+
}
94+
}
8295

8396
newConn := NewConn(
8497
conn,

protocol_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,91 @@ func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) {
857857
}
858858
}
859859

860+
func TestIgnoreUpstreamPolicyIgnoresIpFromProxyHeader(t *testing.T) {
861+
l, err := net.Listen("tcp", "127.0.0.1:0")
862+
if err != nil {
863+
t.Fatalf("err: %v", err)
864+
}
865+
866+
policyFunc := func(downstream net.Addr) (Policy, error) { return IGNORE, nil }
867+
868+
pl := &Listener{Listener: l, DownstreamPolicy: policyFunc}
869+
870+
cliResult := make(chan error)
871+
go func() {
872+
conn, err := net.Dial("tcp", pl.Addr().String())
873+
if err != nil {
874+
cliResult <- err
875+
return
876+
}
877+
defer conn.Close()
878+
879+
// Write out the header!
880+
header := &Header{
881+
Version: 2,
882+
Command: PROXY,
883+
TransportProtocol: TCPv4,
884+
SourceAddr: &net.TCPAddr{
885+
IP: net.ParseIP("10.1.1.1"),
886+
Port: 1000,
887+
},
888+
DestinationAddr: &net.TCPAddr{
889+
IP: net.ParseIP("20.2.2.2"),
890+
Port: 2000,
891+
},
892+
}
893+
if _, err := header.WriteTo(conn); err != nil {
894+
cliResult <- err
895+
return
896+
}
897+
898+
if _, err := conn.Write([]byte("ping")); err != nil {
899+
cliResult <- err
900+
return
901+
}
902+
903+
recv := make([]byte, 4)
904+
if _, err = conn.Read(recv); err != nil {
905+
cliResult <- err
906+
return
907+
}
908+
if !bytes.Equal(recv, []byte("pong")) {
909+
cliResult <- fmt.Errorf("bad: %v", recv)
910+
return
911+
}
912+
913+
close(cliResult)
914+
}()
915+
916+
conn, err := pl.Accept()
917+
if err != nil {
918+
t.Fatalf("err: %v", err)
919+
}
920+
defer conn.Close()
921+
922+
recv := make([]byte, 4)
923+
if _, err = conn.Read(recv); err != nil {
924+
t.Fatalf("err: %v", err)
925+
}
926+
if !bytes.Equal(recv, []byte("ping")) {
927+
t.Fatalf("bad: %v", recv)
928+
}
929+
930+
if _, err := conn.Write([]byte("pong")); err != nil {
931+
t.Fatalf("err: %v", err)
932+
}
933+
934+
// Check the remote addr
935+
addr := conn.RemoteAddr().(*net.TCPAddr)
936+
if addr.IP.String() != "127.0.0.1" {
937+
t.Fatalf("bad: %v", addr)
938+
}
939+
err = <-cliResult
940+
if err != nil {
941+
t.Fatalf("client error: %v", err)
942+
}
943+
}
944+
860945
func Test_AllOptionsAreRecognized(t *testing.T) {
861946
recognizedOpt1 := false
862947
opt1 := func(c *Conn) {

0 commit comments

Comments
 (0)