Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions helper/http2/http2.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,15 @@ func (srv *Server) serveConn(baseCtx context.Context, conn net.Conn) error {
}()

ctx := baseCtx
// We don't check if srv.h1.ConnContext is nil so http.Server works the same
// with or without this middleware.
// For more info, see https://github.com/pires/go-proxyproto/pull/140/changes#r2725568706.
if connCtx := srv.h1.ConnContext(ctx, conn); connCtx != nil {
ctx = connCtx
// Mirror net/http.Server ConnContext behavior (see server.go around line 3469).
connCtx := ctx
if cc := srv.h1.ConnContext; cc != nil {
connCtx = cc(ctx, conn)
if connCtx == nil {
panic("ConnContext returned nil")
}
}
ctx = connCtx

opts := http2.ServeConnOpts{Context: ctx, BaseConfig: srv.h1}
srv.h2.ServeConn(conn, &opts)
Expand All @@ -186,6 +189,11 @@ func (srv *Server) serveConn(baseCtx context.Context, conn net.Conn) error {
}
}

// Close closes the server by closing all listeners.
func (srv *Server) Close() error {
return srv.closeListeners()
}

func (srv *Server) closeListeners() error {
srv.mu.Lock()
defer srv.mu.Unlock()
Expand Down
63 changes: 63 additions & 0 deletions helper/http2/http2_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package http2

import (
"context"
"net"
"net/http"
"testing"
"time"

"github.com/pires/go-proxyproto"
)

// TestServeConn_ConnContextReturnsNil lives in package http2 (not http2_test) so
// it can call the unexported serveConn method directly and recover the panic in
// the same goroutine, which is not possible through the public Serve API because
// Serve spawns a new goroutine per connection.
func TestServeConn_ConnContextReturnsNil(t *testing.T) {
srv := NewServer(&http.Server{
ReadHeaderTimeout: 5 * time.Second,
Handler: http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}),
ConnContext: func(_ context.Context, _ net.Conn) context.Context {
return nil
},
}, nil)

// Create a pipe and write a PROXY header with h2 ALPN to trigger the H2 path.
clientConn, serverConn := net.Pipe()
defer func() { _ = clientConn.Close() }()
defer func() { _ = serverConn.Close() }()

header := proxyproto.Header{
Version: 2,
Command: proxyproto.LOCAL,
TransportProtocol: proxyproto.UNSPEC,
}
if err := header.SetTLVs([]proxyproto.TLV{{
Type: proxyproto.PP2_TYPE_ALPN,
Value: []byte("h2"),
}}); err != nil {
t.Fatalf("failed to set TLVs: %v", err)
}

// Write the header in a goroutine because net.Pipe is synchronous.
go func() {
_, _ = header.WriteTo(clientConn)
_ = clientConn.Close()
}()

pConn := proxyproto.NewConn(serverConn)

defer func() {
r := recover()
if r == nil {
t.Fatal("expected panic from ConnContext returning nil")
}
msg, ok := r.(string)
if !ok || msg != "ConnContext returned nil" {
t.Fatalf("expected panic message 'ConnContext returned nil', got: %v", r)
}
}()

_ = srv.serveConn(context.Background(), pConn)
}
122 changes: 119 additions & 3 deletions helper/http2/http2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import (
)

func ExampleServer() {
ln, err := net.Listen("tcp", "localhost:80")
addr := "localhost:8080"

ln, err := net.Listen("tcp", addr)
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
Expand All @@ -36,9 +38,26 @@ func ExampleServer() {
_, _ = w.Write([]byte("Hello world!\n"))
}),
}, nil)
if err := server.Serve(proxyLn); err != nil {
log.Fatalf("failed to serve: %v", err)
// Run the server in a goroutine.
go func() {
if err := server.Serve(proxyLn); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to serve: %v", err)
}
}()

resp, err := http.Get("http://" + addr)
if err != nil {
log.Fatalf("failed to perform HTTP request: %v", err)
}
if err := resp.Body.Close(); err != nil {
log.Fatalf("failed to close response body: %v", err)
}

if err := server.Close(); err != nil {
log.Fatalf("failed to close server: %v", err)
}

// Output:
}

type contextKey string
Expand Down Expand Up @@ -158,6 +177,76 @@ func TestServer_h2_tls(t *testing.T) {
}
}

func TestServer_h1_nil_ConnContext(t *testing.T) {
addr, server := newTestServerWithoutConnContext(t)
t.Cleanup(func() {
if err := server.Close(); err != nil {
t.Errorf("failed to close server: %v", err)
}
})

resp, err := http.Get("http://" + addr)
if err != nil {
t.Fatalf("failed to perform HTTP request: %v", err)
}
if err := resp.Body.Close(); err != nil {
t.Fatalf("failed to close response body: %v", err)
}
}

func TestServer_h2_nil_ConnContext(t *testing.T) {
addr, server := newTestServerWithoutConnContext(t)
t.Cleanup(func() {
if err := server.Close(); err != nil {
t.Errorf("failed to close server: %v", err)
}
})

conn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer func() {
if err := conn.Close(); err != nil {
t.Errorf("failed to close connection: %v", err)
}
}()

proxyHeader := proxyproto.Header{
Version: 2,
Command: proxyproto.LOCAL,
TransportProtocol: proxyproto.UNSPEC,
}
tlvs := []proxyproto.TLV{{
Type: proxyproto.PP2_TYPE_ALPN,
Value: []byte("h2"),
}}
if err := proxyHeader.SetTLVs(tlvs); err != nil {
t.Fatalf("failed to set TLVs: %v", err)
}
if _, err := proxyHeader.WriteTo(conn); err != nil {
t.Fatalf("failed to write PROXY header: %v", err)
}

h2Conn, err := new(http2.Transport).NewClientConn(conn)
if err != nil {
t.Fatalf("failed to create HTTP connection: %v", err)
}

req, err := http.NewRequest(http.MethodGet, "http://"+addr, nil)
if err != nil {
t.Fatalf("failed to create HTTP request: %v", err)
}

resp, err := h2Conn.RoundTrip(req)
if err != nil {
t.Fatalf("failed to perform HTTP request: %v", err)
}
if err := resp.Body.Close(); err != nil {
t.Fatalf("failed to close response body: %v", err)
}
}

func newTestServer(t *testing.T) (addr string, server *http.Server) {
ln, err := net.Listen("tcp", "localhost:0")
if err != nil {
Expand Down Expand Up @@ -239,6 +328,33 @@ func newTLSTestServer(t *testing.T) (addr string, server *http.Server) {
return ln.Addr().String(), server
}

func newTestServerWithoutConnContext(t *testing.T) (addr string, server *http.Server) {
ln, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}

server = &http.Server{
ReadHeaderTimeout: 5 * time.Second,
Handler: http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}),
}

h2Server := h2proxy.NewServer(server, nil)
done := make(chan error, 1)
go func() {
done <- h2Server.Serve(&proxyproto.Listener{Listener: ln})
}()

t.Cleanup(func() {
err := <-done
if err != nil && !errors.Is(err, http.ErrServerClosed) {
t.Fatalf("failed to serve: %v", err)
}
})

return ln.Addr().String(), server
}

func testTLSConfig(t *testing.T) *tls.Config {
t.Helper()

Expand Down
Loading