Skip to content

Commit a55d19c

Browse files
authored
Merge pull request #21 from igorwwwwwwwwwwwwwwwwwwww/patch-1
header: return err from reader.Peek()
2 parents 62dfc14 + fb3158b commit a55d19c

File tree

2 files changed

+48
-3
lines changed

2 files changed

+48
-3
lines changed

header.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,34 @@ func (header *Header) TLVs() ([]TLV, error) {
123123
// Also, this operation will block until enough bytes are available for peeking.
124124
func Read(reader *bufio.Reader) (*Header, error) {
125125
// In order to improve speed for small non-PROXYed packets, take a peek at the first byte alone.
126-
if b1, err := reader.Peek(1); err == nil && (bytes.Equal(b1[:1], SIGV1[:1]) || bytes.Equal(b1[:1], SIGV2[:1])) {
127-
if signature, err := reader.Peek(5); err == nil && bytes.Equal(signature[:5], SIGV1) {
126+
b1, err := reader.Peek(1)
127+
if err != nil {
128+
if err == io.EOF {
129+
return nil, ErrNoProxyProtocol
130+
}
131+
return nil, err
132+
}
133+
134+
if bytes.Equal(b1[:1], SIGV1[:1]) || bytes.Equal(b1[:1], SIGV2[:1]) {
135+
signature, err := reader.Peek(5)
136+
if err != nil {
137+
if err == io.EOF {
138+
return nil, ErrNoProxyProtocol
139+
}
140+
return nil, err
141+
}
142+
if bytes.Equal(signature[:5], SIGV1) {
128143
return parseVersion1(reader)
129-
} else if signature, err := reader.Peek(12); err == nil && bytes.Equal(signature[:12], SIGV2) {
144+
}
145+
146+
signature, err = reader.Peek(12)
147+
if err != nil {
148+
if err == io.EOF {
149+
return nil, ErrNoProxyProtocol
150+
}
151+
return nil, err
152+
}
153+
if bytes.Equal(signature[:12], SIGV2) {
130154
return parseVersion2(reader)
131155
}
132156
}

header_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package proxyproto
33
import (
44
"bufio"
55
"bytes"
6+
"errors"
67
"net"
78
"testing"
89
"time"
@@ -21,6 +22,8 @@ const (
2122
var (
2223
v4addr = net.ParseIP(IP4_ADDR).To4()
2324
v6addr = net.ParseIP(IP6_ADDR).To16()
25+
26+
errReadIntentionallyBroken = errors.New("read is intentionally broken")
2427
)
2528

2629
type timeoutReader []byte
@@ -30,6 +33,12 @@ func (t *timeoutReader) Read([]byte) (int, error) {
3033
return 0, nil
3134
}
3235

36+
type errorReader []byte
37+
38+
func (e *errorReader) Read([]byte) (int, error) {
39+
return 0, errReadIntentionallyBroken
40+
}
41+
3342
func TestReadTimeoutV1Invalid(t *testing.T) {
3443
var b timeoutReader
3544
reader := bufio.NewReader(&b)
@@ -41,6 +50,18 @@ func TestReadTimeoutV1Invalid(t *testing.T) {
4150
}
4251
}
4352

53+
func TestReadTimeoutPropagatesReadError(t *testing.T) {
54+
var e errorReader
55+
reader := bufio.NewReader(&e)
56+
_, err := ReadTimeout(reader, 50*time.Millisecond)
57+
58+
if err == nil {
59+
t.Fatalf("expected error %s", errReadIntentionallyBroken)
60+
} else if err != errReadIntentionallyBroken {
61+
t.Fatalf("expected error %s, actual %s", errReadIntentionallyBroken, err)
62+
}
63+
}
64+
4465
func TestEqualsTo(t *testing.T) {
4566
var headersEqual = []struct {
4667
this, that *Header

0 commit comments

Comments
 (0)