Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,31 @@ func ipFromAddr(upstream net.Addr) (net.IP, error) {
return upstreamIP, nil
}

// IgnoreProxyHeaderNotOnInterface retuns a ConnPolicyFunc which can be used to
// TrustProxyHeaderFrom returns a ConnPolicyFunc which can be used to decide
// whether to use or reject PROXY headers based on the source IP of the
// connection. This policy ensures that only trusted sources can set the PROXY
// header. Connections from IPs not in the trusted list will be rejected.
func TrustProxyHeaderFrom(trustedIPs ...net.IP) ConnPolicyFunc {
return func(connOpts ConnPolicyOptions) (Policy, error) {
ip, err := ipFromAddr(connOpts.Upstream)
if err != nil {
return REJECT, err
}

for _, trustedIP := range trustedIPs {
if trustedIP.Equal(ip) {
return USE, nil
}
}

return REJECT, nil
}
}

// IgnoreProxyHeaderNotOnInterface returns a ConnPolicyFunc which can be used to
// decide whether to use or ignore PROXY headers depending on the connection
// being made on a specific interface. This policy can be used when the server
// is bound to multiple interfaces but wants to allow on only one interface.
// being made on specific interfaces. This policy can be used when the server
// is bound to multiple interfaces but wants to allow on one or more interfaces.
func IgnoreProxyHeaderNotOnInterface(allowedIP net.IP) ConnPolicyFunc {
return func(connOpts ConnPolicyOptions) (Policy, error) {
ip, err := ipFromAddr(connOpts.Downstream)
Expand Down
41 changes: 39 additions & 2 deletions policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,44 @@ func TestSkipProxyHeaderForCIDR(t *testing.T) {
}
}

func TestTrustProxyHeaderFrom(t *testing.T) {
upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738")
if err != nil {
t.Fatalf("err: %v", err)
}

var cases = []struct {
name string
policy ConnPolicyFunc
upstreamAddr net.Addr
expectedPolicy Policy
expectError bool
}{
{"reject header from untrusted source", TrustProxyHeaderFrom(net.ParseIP("192.0.2.1")), upstream, REJECT, false},
{"use header from trusted load balancer", TrustProxyHeaderFrom(net.ParseIP("10.0.0.3")), upstream, USE, false},
{"use header when source matches any trusted IP", TrustProxyHeaderFrom(net.ParseIP("192.0.2.1"), net.ParseIP("10.0.0.3")), upstream, USE, false},
{"invalid address should return error", TrustProxyHeaderFrom(net.ParseIP("10.0.0.3")), failingAddr{}, REJECT, true},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
policy, err := tc.policy(ConnPolicyOptions{
Upstream: tc.upstreamAddr,
})
if !tc.expectError && err != nil {
t.Fatalf("err: %v", err)
}
if tc.expectError && err == nil {
t.Fatal("Expected error, got none")
}

if policy != tc.expectedPolicy {
t.Fatalf("Expected policy %v, got %v", tc.expectedPolicy, policy)
}
})
}
}

func TestIgnoreProxyHeaderNotOnInterface(t *testing.T) {
downstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738")
if err != nil {
Expand All @@ -225,7 +263,7 @@ func TestIgnoreProxyHeaderNotOnInterface(t *testing.T) {
expectedPolicy Policy
expectError bool
}{
{"ignore header for requests non on interface", IgnoreProxyHeaderNotOnInterface(net.ParseIP("192.0.2.1")), downstream, IGNORE, false},
{"ignore header for requests not on interface", IgnoreProxyHeaderNotOnInterface(net.ParseIP("192.0.2.1")), downstream, IGNORE, false},
{"use headers for requests on interface", IgnoreProxyHeaderNotOnInterface(net.ParseIP("10.0.0.3")), downstream, USE, false},
{"invalid address should return error", IgnoreProxyHeaderNotOnInterface(net.ParseIP("10.0.0.3")), failingAddr{}, REJECT, true},
}
Expand All @@ -247,5 +285,4 @@ func TestIgnoreProxyHeaderNotOnInterface(t *testing.T) {
}
})
}

}
Loading