Skip to content

Commit 5ddbd28

Browse files
committed
Merge branch 'compress'
2 parents 404e6b1 + 6c51b25 commit 5ddbd28

File tree

3 files changed

+65
-29
lines changed

3 files changed

+65
-29
lines changed

compression.go

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,32 @@ import (
1313
)
1414

1515
var (
16-
flateWriterPool = sync.Pool{}
16+
flateWriterPool = sync.Pool{New: func() interface{} {
17+
fw, _ := flate.NewWriter(nil, 3)
18+
return fw
19+
}}
20+
flateReaderPool = sync.Pool{New: func() interface{} {
21+
return flate.NewReader(nil)
22+
}}
1723
)
1824

19-
func decompressNoContextTakeover(r io.Reader) io.Reader {
25+
func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
2026
const tail =
2127
// Add four bytes as specified in RFC
2228
"\x00\x00\xff\xff" +
2329
// Add final block to squelch unexpected EOF error from flate reader.
2430
"\x01\x00\x00\xff\xff"
25-
return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
31+
32+
fr, _ := flateReaderPool.Get().(io.ReadCloser)
33+
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
34+
return &flateReadWrapper{fr}
2635
}
2736

28-
func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
37+
func compressNoContextTakeover(w io.WriteCloser) io.WriteCloser {
2938
tw := &truncWriter{w: w}
30-
i := flateWriterPool.Get()
31-
var fw *flate.Writer
32-
var err error
33-
if i == nil {
34-
fw, err = flate.NewWriter(tw, 3)
35-
} else {
36-
fw = i.(*flate.Writer)
37-
fw.Reset(tw)
38-
}
39-
return &flateWrapper{fw: fw, tw: tw}, err
39+
fw, _ := flateWriterPool.Get().(*flate.Writer)
40+
fw.Reset(tw)
41+
return &flateWriteWrapper{fw: fw, tw: tw}
4042
}
4143

4244
// truncWriter is an io.Writer that writes all but the last four bytes of the
@@ -75,19 +77,19 @@ func (w *truncWriter) Write(p []byte) (int, error) {
7577
return n + nn, err
7678
}
7779

78-
type flateWrapper struct {
80+
type flateWriteWrapper struct {
7981
fw *flate.Writer
8082
tw *truncWriter
8183
}
8284

83-
func (w *flateWrapper) Write(p []byte) (int, error) {
85+
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
8486
if w.fw == nil {
8587
return 0, errWriteClosed
8688
}
8789
return w.fw.Write(p)
8890
}
8991

90-
func (w *flateWrapper) Close() error {
92+
func (w *flateWriteWrapper) Close() error {
9193
if w.fw == nil {
9294
return errWriteClosed
9395
}
@@ -103,3 +105,31 @@ func (w *flateWrapper) Close() error {
103105
}
104106
return err2
105107
}
108+
109+
type flateReadWrapper struct {
110+
fr io.ReadCloser
111+
}
112+
113+
func (r *flateReadWrapper) Read(p []byte) (int, error) {
114+
if r.fr == nil {
115+
return 0, io.ErrClosedPipe
116+
}
117+
n, err := r.fr.Read(p)
118+
if err == io.EOF {
119+
// Preemptively place the reader back in the pool. This helps with
120+
// scenarios where the application does not call NextReader() soon after
121+
// this final read.
122+
r.Close()
123+
}
124+
return n, err
125+
}
126+
127+
func (r *flateReadWrapper) Close() error {
128+
if r.fr == nil {
129+
return io.ErrClosedPipe
130+
}
131+
err := r.fr.Close()
132+
flateReaderPool.Put(r.fr)
133+
r.fr = nil
134+
return err
135+
}

conn.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,10 @@ type Conn struct {
235235
writeErr error
236236

237237
enableWriteCompression bool
238-
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)
238+
newCompressionWriter func(io.WriteCloser) io.WriteCloser
239239

240240
// Read fields
241+
reader io.ReadCloser // the current reader returned to the application
241242
readErr error
242243
br *bufio.Reader
243244
readRemaining int64 // bytes remaining in current frame.
@@ -253,7 +254,7 @@ type Conn struct {
253254
messageReader *messageReader // the current low-level reader
254255

255256
readDecompress bool // whether last read frame had RSV1 set
256-
newDecompressionReader func(io.Reader) io.Reader
257+
newDecompressionReader func(io.Reader) io.ReadCloser
257258
}
258259

259260
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@@ -443,11 +444,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
443444
}
444445
c.writer = mw
445446
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
446-
w, err := c.newCompressionWriter(c.writer)
447-
if err != nil {
448-
c.writer = nil
449-
return nil, err
450-
}
447+
w := c.newCompressionWriter(c.writer)
451448
mw.compress = true
452449
c.writer = w
453450
}
@@ -855,6 +852,11 @@ func (c *Conn) handleProtocolError(message string) error {
855852
// permanent. Once this method returns a non-nil error, all subsequent calls to
856853
// this method return the same error.
857854
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
855+
// Close previous reader, only relevant for decompression.
856+
if c.reader != nil {
857+
c.reader.Close()
858+
c.reader = nil
859+
}
858860

859861
c.messageReader = nil
860862
c.readLength = 0
@@ -867,11 +869,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
867869
}
868870
if frameType == TextMessage || frameType == BinaryMessage {
869871
c.messageReader = &messageReader{c}
870-
var r io.Reader = c.messageReader
872+
c.reader = c.messageReader
871873
if c.readDecompress {
872-
r = c.newDecompressionReader(r)
874+
c.reader = c.newDecompressionReader(c.reader)
873875
}
874-
return frameType, r, nil
876+
return frameType, c.reader, nil
875877
}
876878
}
877879

@@ -933,6 +935,10 @@ func (r *messageReader) Read(b []byte) (int, error) {
933935
return 0, err
934936
}
935937

938+
func (r *messageReader) Close() error {
939+
return nil
940+
}
941+
936942
// ReadMessage is a helper method for getting a reader using NextReader and
937943
// reading from that reader to a buffer.
938944
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {

doc.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@
150150
// application's responsibility to check the Origin header before calling
151151
// Upgrade.
152152
//
153-
// Compression [Experimental]
153+
// Compression
154154
//
155155
// Per message compression extensions (RFC 7692) are experimentally supported
156156
// by this package in a limited capacity. Setting the EnableCompression option
@@ -162,7 +162,7 @@
162162
// Per message compression of messages written to a connection can be enabled
163163
// or disabled by calling the corresponding Conn method:
164164
//
165-
// conn.EnableWriteCompression(true)
165+
// conn.EnableWriteCompression(true)
166166
//
167167
// Currently this package does not support compression with "context takeover".
168168
// This means that messages must be compressed and decompressed in isolation,

0 commit comments

Comments
 (0)