Skip to content

Commit 4f677b5

Browse files
authored
Add SKIP policy to not expect a PROXY header
2 parents 2fac219 + e9fdbf2 commit 4f677b5

File tree

4 files changed

+103
-0
lines changed

4 files changed

+103
-0
lines changed

policy.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,31 @@ const (
3232
// a PROXY header is not present, subsequent reads do not. It is the task
3333
// of the code using the connection to handle that case properly.
3434
REQUIRE
35+
// SKIP accepts a connection without requiring the PROXY header
36+
// Note: an example usage can be found in the SkipProxyHeaderForCIDR
37+
// function.
38+
SKIP
3539
)
3640

41+
// SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a
42+
// connection from a skipHeaderCIDR without requiring a PROXY header, e.g.
43+
// Kubernetes pods local traffic. The def is a policy to use when an upstream
44+
// address doesn't match the skipHeaderCIDR.
45+
func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc {
46+
return func(upstream net.Addr) (Policy, error) {
47+
ip, err := ipFromAddr(upstream)
48+
if err != nil {
49+
return def, err
50+
}
51+
52+
if skipHeaderCIDR != nil && skipHeaderCIDR.Contains(ip) {
53+
return SKIP, nil
54+
}
55+
56+
return def, nil
57+
}
58+
}
59+
3760
// WithPolicy adds given policy to a connection when passed as option to NewConn()
3861
func WithPolicy(p Policy) func(*Conn) {
3962
return func(c *Conn) {

policy_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,26 @@ func Test_MustStrictWhiteListPolicyPanicsWithInvalidIpRange(t *testing.T) {
188188

189189
MustStrictWhiteListPolicy([]string{"20/80"})
190190
}
191+
192+
func TestSkipProxyHeaderForCIDR(t *testing.T) {
193+
_, cidr, _ := net.ParseCIDR("192.0.2.1/24")
194+
f := SkipProxyHeaderForCIDR(cidr, REJECT)
195+
196+
upstream, _ := net.ResolveTCPAddr("tcp", "192.0.2.255:12345")
197+
policy, err := f(upstream)
198+
if err != nil {
199+
t.Fatalf("err: %v", err)
200+
}
201+
if policy != SKIP {
202+
t.Errorf("Expected a SKIP policy for the %s address", upstream)
203+
}
204+
205+
upstream, _ = net.ResolveTCPAddr("tcp", "8.8.8.8:12345")
206+
policy, err = f(upstream)
207+
if err != nil {
208+
t.Fatalf("err: %v", err)
209+
}
210+
if policy != REJECT {
211+
t.Errorf("Expected a REJECT policy for the %s address", upstream)
212+
}
213+
}

protocol.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ func (p *Listener) Accept() (net.Conn, error) {
7474
conn.Close()
7575
return nil, err
7676
}
77+
// Handle a connection as a regular one
78+
if proxyHeaderPolicy == SKIP {
79+
return conn, nil
80+
}
7781
}
7882

7983
newConn := NewConn(

protocol_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -923,6 +923,59 @@ func TestReadingIsRefusedOnErrorWhenLocalAddrRequestedFirst(t *testing.T) {
923923
}
924924
}
925925

926+
func TestSkipProxyProtocolPolicy(t *testing.T) {
927+
l, err := net.Listen("tcp", "127.0.0.1:0")
928+
if err != nil {
929+
t.Fatalf("err: %v", err)
930+
}
931+
932+
policyFunc := func(upstream net.Addr) (Policy, error) { return SKIP, nil }
933+
934+
pl := &Listener{
935+
Listener: l,
936+
Policy: policyFunc,
937+
}
938+
939+
cliResult := make(chan error)
940+
ping := []byte("ping")
941+
go func() {
942+
conn, err := net.Dial("tcp", pl.Addr().String())
943+
if err != nil {
944+
cliResult <- err
945+
return
946+
}
947+
defer conn.Close()
948+
conn.Write(ping)
949+
close(cliResult)
950+
}()
951+
952+
conn, err := pl.Accept()
953+
if err != nil {
954+
t.Fatalf("err: %v", err)
955+
}
956+
defer conn.Close()
957+
958+
_, ok := conn.(*net.TCPConn)
959+
if !ok {
960+
t.Fatal("err: should be a tcp connection")
961+
}
962+
_ = conn.LocalAddr()
963+
recv := make([]byte, 4)
964+
_, err = conn.Read(recv)
965+
if err != nil {
966+
t.Fatalf("Unexpected read error: %v", err)
967+
}
968+
969+
if !bytes.Equal(ping, recv) {
970+
t.Fatalf("Unexpected %s data while expected %s", recv, ping)
971+
}
972+
973+
err = <-cliResult
974+
if err != nil {
975+
t.Fatalf("client error: %v", err)
976+
}
977+
}
978+
926979
func Test_ConnectionCasts(t *testing.T) {
927980
l, err := net.Listen("tcp", "127.0.0.1:0")
928981
if err != nil {

0 commit comments

Comments
 (0)