@@ -235,9 +235,10 @@ type Conn struct {
235235 writeErr error
236236
237237 enableWriteCompression bool
238- newCompressionWriter func (io.WriteCloser ) ( io.WriteCloser , error )
238+ newCompressionWriter func (io.WriteCloser ) io.WriteCloser
239239
240240 // Read fields
241+ reader io.ReadCloser // the current reader returned to the application
241242 readErr error
242243 br * bufio.Reader
243244 readRemaining int64 // bytes remaining in current frame.
@@ -253,7 +254,7 @@ type Conn struct {
253254 messageReader * messageReader // the current low-level reader
254255
255256 readDecompress bool // whether last read frame had RSV1 set
256- newDecompressionReader func (io.Reader ) io.Reader
257+ newDecompressionReader func (io.Reader ) io.ReadCloser
257258}
258259
259260func newConn (conn net.Conn , isServer bool , readBufferSize , writeBufferSize int ) * Conn {
@@ -443,11 +444,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
443444 }
444445 c .writer = mw
445446 if c .newCompressionWriter != nil && c .enableWriteCompression && isData (messageType ) {
446- w , err := c .newCompressionWriter (c .writer )
447- if err != nil {
448- c .writer = nil
449- return nil , err
450- }
447+ w := c .newCompressionWriter (c .writer )
451448 mw .compress = true
452449 c .writer = w
453450 }
@@ -855,6 +852,11 @@ func (c *Conn) handleProtocolError(message string) error {
855852// permanent. Once this method returns a non-nil error, all subsequent calls to
856853// this method return the same error.
857854func (c * Conn ) NextReader () (messageType int , r io.Reader , err error ) {
855+ // Close previous reader, only relevant for decompression.
856+ if c .reader != nil {
857+ c .reader .Close ()
858+ c .reader = nil
859+ }
858860
859861 c .messageReader = nil
860862 c .readLength = 0
@@ -867,11 +869,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
867869 }
868870 if frameType == TextMessage || frameType == BinaryMessage {
869871 c .messageReader = & messageReader {c }
870- var r io. Reader = c .messageReader
872+ c . reader = c .messageReader
871873 if c .readDecompress {
872- r = c .newDecompressionReader (r )
874+ c . reader = c .newDecompressionReader (c . reader )
873875 }
874- return frameType , r , nil
876+ return frameType , c . reader , nil
875877 }
876878 }
877879
@@ -933,6 +935,10 @@ func (r *messageReader) Read(b []byte) (int, error) {
933935 return 0 , err
934936}
935937
938+ func (r * messageReader ) Close () error {
939+ return nil
940+ }
941+
936942// ReadMessage is a helper method for getting a reader using NextReader and
937943// reading from that reader to a buffer.
938944func (c * Conn ) ReadMessage () (messageType int , p []byte , err error ) {
0 commit comments