Skip to content

Commit aff9d13

Browse files
committed
Switch to stateless compression with klauspost/compress
1 parent c5b0a00 commit aff9d13

File tree

7 files changed

+28
-40
lines changed

7 files changed

+28
-40
lines changed

compress_notjs.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ func newSlidingWindow(n int) *slidingWindow {
125125
}
126126
}
127127

128+
func (w *slidingWindow) getBuf() []byte {
129+
if w == nil {
130+
return nil
131+
}
132+
return w.buf
133+
}
134+
128135
func (w *slidingWindow) write(p []byte) {
129136
if len(p) >= cap(w.buf) {
130137
w.buf = w.buf[:cap(w.buf)]

conn_notjs.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ func (c *Conn) close(err error) {
141141
c.writeFrameMu.Lock(context.Background())
142142
putBufioWriter(c.bw)
143143
}
144-
c.msgWriter.close()
145144

146145
c.msgReader.close()
147146
if c.client {

conn_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
351351
ctx, cancel := context.WithCancel(tt.ctx)
352352

353353
discardLoopErr := xsync.Go(func() error {
354+
defer c.Close(websocket.StatusInternalError, "")
355+
354356
for {
355357
_, _, err := c.Read(ctx)
356358
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ require (
99
github.com/golang/protobuf v1.3.3
1010
github.com/google/go-cmp v0.4.0
1111
github.com/gorilla/websocket v1.4.1
12+
github.com/klauspost/compress v1.10.0
1213
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
1314
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543
1415
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
1010
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
1111
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
1212
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
13+
github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y=
14+
github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
1315
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
1416
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
1517
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=

read.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,7 @@ func (mr *msgReader) resetFlate() {
9191
mr.dict = newSlidingWindow(32768)
9292
}
9393

94-
if mr.flateContextTakeover() {
95-
mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf)
96-
} else {
97-
mr.flateReader = getFlateReader(readerFunc(mr.read), nil)
98-
}
94+
mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.getBuf())
9995
mr.limitReader.r = mr.flateReader
10096
mr.flateTail.Reset(deflateMessageTail)
10197
}

write.go

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@ package websocket
44

55
import (
66
"bufio"
7-
"compress/flate"
87
"context"
98
"crypto/rand"
109
"encoding/binary"
1110
"io"
12-
"sync"
1311
"time"
1412

13+
"github.com/klauspost/compress/flate"
1514
"golang.org/x/xerrors"
1615

1716
"nhooyr.io/websocket/internal/errd"
@@ -51,16 +50,15 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
5150
type msgWriter struct {
5251
c *Conn
5352

54-
mu *mu
55-
writeMu sync.Mutex
53+
mu *mu
5654

5755
ctx context.Context
5856
opcode opcode
5957
closed bool
6058
flate bool
6159

62-
trimWriter *trimLastFourBytesWriter
63-
flateWriter *flate.Writer
60+
trimWriter *trimLastFourBytesWriter
61+
dict *slidingWindow
6462
}
6563

6664
func newMsgWriter(c *Conn) *msgWriter {
@@ -72,16 +70,16 @@ func newMsgWriter(c *Conn) *msgWriter {
7270
}
7371

7472
func (mw *msgWriter) ensureFlate() {
73+
if mw.flateContextTakeover() && mw.dict == nil {
74+
mw.dict = newSlidingWindow(8192)
75+
}
76+
7577
if mw.trimWriter == nil {
7678
mw.trimWriter = &trimLastFourBytesWriter{
7779
w: writerFunc(mw.write),
7880
}
7981
}
8082

81-
if mw.flateWriter == nil {
82-
mw.flateWriter = getFlateWriter(mw.trimWriter)
83-
}
84-
8583
mw.flate = true
8684
}
8785

@@ -138,20 +136,10 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
138136
return nil
139137
}
140138

141-
func (mw *msgWriter) returnFlateWriter() {
142-
if mw.flateWriter != nil {
143-
putFlateWriter(mw.flateWriter)
144-
mw.flateWriter = nil
145-
}
146-
}
147-
148139
// Write writes the given bytes to the WebSocket connection.
149140
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
150141
defer errd.Wrap(&err, "failed to write")
151142

152-
mw.writeMu.Lock()
153-
defer mw.writeMu.Unlock()
154-
155143
if mw.closed {
156144
return 0, xerrors.New("cannot use closed writer")
157145
}
@@ -165,7 +153,11 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
165153
}
166154

167155
if mw.flate {
168-
return mw.flateWriter.Write(p)
156+
err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.getBuf())
157+
if mw.flateContextTakeover() {
158+
mw.dict.write(p)
159+
}
160+
return len(p), err
169161
}
170162

171163
return mw.write(p)
@@ -184,17 +176,14 @@ func (mw *msgWriter) write(p []byte) (int, error) {
184176
func (mw *msgWriter) Close() (err error) {
185177
defer errd.Wrap(&err, "failed to close writer")
186178

187-
mw.writeMu.Lock()
188-
defer mw.writeMu.Unlock()
189-
190179
if mw.closed {
191180
return xerrors.New("cannot use closed writer")
192181
}
193182

194183
if mw.flate {
195-
err = mw.flateWriter.Flush()
184+
err = flate.StatelessDeflate(mw.trimWriter, nil, true, mw.dict.getBuf())
196185
if err != nil {
197-
return xerrors.Errorf("failed to flush flate writer: %w", err)
186+
return xerrors.Errorf("failed to flush flate: %w", err)
198187
}
199188
}
200189

@@ -207,18 +196,10 @@ func (mw *msgWriter) Close() (err error) {
207196
return xerrors.Errorf("failed to write fin frame: %w", err)
208197
}
209198

210-
if mw.flate && !mw.flateContextTakeover() {
211-
mw.returnFlateWriter()
212-
}
213199
mw.mu.Unlock()
214200
return nil
215201
}
216202

217-
func (mw *msgWriter) close() {
218-
mw.writeMu.Lock()
219-
mw.returnFlateWriter()
220-
}
221-
222203
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
223204
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
224205
defer cancel()

0 commit comments

Comments
 (0)